native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self},
  4    io::BufReader,
  5    stream::{SelectAll, StreamExt},
  6    AsyncBufReadExt as _, SinkExt as _,
  7};
  8use gpui::{EntityId, Task, View, WindowContext};
  9use jupyter_protocol::{JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply};
 10use project::Fs;
 11use runtimelib::{dirs, ConnectionInfo, ExecutionState};
 12use smol::{net::TcpListener, process::Command};
 13use std::{
 14    env,
 15    fmt::Debug,
 16    net::{IpAddr, Ipv4Addr, SocketAddr},
 17    path::PathBuf,
 18    sync::Arc,
 19};
 20use uuid::Uuid;
 21
 22use crate::Session;
 23
 24use super::RunningKernel;
 25
 26#[derive(Debug, Clone)]
 27pub struct LocalKernelSpecification {
 28    pub name: String,
 29    pub path: PathBuf,
 30    pub kernelspec: JupyterKernelspec,
 31}
 32
 33impl PartialEq for LocalKernelSpecification {
 34    fn eq(&self, other: &Self) -> bool {
 35        self.name == other.name && self.path == other.path
 36    }
 37}
 38
 39impl Eq for LocalKernelSpecification {}
 40
 41impl LocalKernelSpecification {
 42    #[must_use]
 43    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 44        let argv = &self.kernelspec.argv;
 45
 46        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 47        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 48        anyhow::ensure!(
 49            argv.iter().any(|arg| arg == "{connection_file}"),
 50            "Missing 'connection_file' in argv in kernelspec {}",
 51            self.name
 52        );
 53
 54        let mut cmd = util::command::new_smol_command(&argv[0]);
 55
 56        for arg in &argv[1..] {
 57            if arg == "{connection_file}" {
 58                cmd.arg(connection_path);
 59            } else {
 60                cmd.arg(arg);
 61            }
 62        }
 63
 64        if let Some(env) = &self.kernelspec.env {
 65            cmd.envs(env);
 66        }
 67
 68        Ok(cmd)
 69    }
 70}
 71
 72// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 73// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 74async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 75    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 76    addr_zeroport.set_port(0);
 77    let mut ports: [u16; 5] = [0; 5];
 78    for i in 0..5 {
 79        let listener = TcpListener::bind(addr_zeroport).await?;
 80        let addr = listener.local_addr()?;
 81        ports[i] = addr.port();
 82    }
 83    Ok(ports)
 84}
 85
 86pub struct NativeRunningKernel {
 87    pub process: smol::process::Child,
 88    _shell_task: Task<Result<()>>,
 89    _control_task: Task<Result<()>>,
 90    _routing_task: Task<Result<()>>,
 91    connection_path: PathBuf,
 92    _process_status_task: Option<Task<()>>,
 93    pub working_directory: PathBuf,
 94    pub request_tx: mpsc::Sender<JupyterMessage>,
 95    pub execution_state: ExecutionState,
 96    pub kernel_info: Option<KernelInfoReply>,
 97}
 98
 99impl Debug for NativeRunningKernel {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("RunningKernel")
102            .field("process", &self.process)
103            .finish()
104    }
105}
106
107impl NativeRunningKernel {
108    pub fn new(
109        kernel_specification: LocalKernelSpecification,
110        entity_id: EntityId,
111        working_directory: PathBuf,
112        fs: Arc<dyn Fs>,
113        // todo: convert to weak view
114        session: View<Session>,
115        cx: &mut WindowContext,
116    ) -> Task<Result<Box<dyn RunningKernel>>> {
117        cx.spawn(|cx| async move {
118            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
119            let ports = peek_ports(ip).await?;
120
121            let connection_info = ConnectionInfo {
122                transport: "tcp".to_string(),
123                ip: ip.to_string(),
124                stdin_port: ports[0],
125                control_port: ports[1],
126                hb_port: ports[2],
127                shell_port: ports[3],
128                iopub_port: ports[4],
129                signature_scheme: "hmac-sha256".to_string(),
130                key: uuid::Uuid::new_v4().to_string(),
131                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
132            };
133
134            let runtime_dir = dirs::runtime_dir();
135            fs.create_dir(&runtime_dir)
136                .await
137                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
138            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
139            let content = serde_json::to_string(&connection_info)?;
140            fs.atomic_write(connection_path.clone(), content).await?;
141
142            let mut cmd = kernel_specification.command(&connection_path)?;
143
144            let mut process = cmd
145                .current_dir(&working_directory)
146                .stdout(std::process::Stdio::piped())
147                .stderr(std::process::Stdio::piped())
148                .stdin(std::process::Stdio::piped())
149                .kill_on_drop(true)
150                .spawn()
151                .context("failed to start the kernel process")?;
152
153            let session_id = Uuid::new_v4().to_string();
154
155            let mut iopub_socket =
156                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
157                    .await?;
158            let mut shell_socket =
159                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
160            let mut control_socket =
161                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
162
163            let (request_tx, mut request_rx) =
164                futures::channel::mpsc::channel::<JupyterMessage>(100);
165
166            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
167            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
168
169            let mut messages_rx = SelectAll::new();
170            messages_rx.push(control_reply_rx);
171            messages_rx.push(shell_reply_rx);
172
173            cx.spawn({
174                let session = session.clone();
175
176                |mut cx| async move {
177                    while let Some(message) = messages_rx.next().await {
178                        session
179                            .update(&mut cx, |session, cx| {
180                                session.route(&message, cx);
181                            })
182                            .ok();
183                    }
184                    anyhow::Ok(())
185                }
186            })
187            .detach();
188
189            // iopub task
190            cx.spawn({
191                let session = session.clone();
192
193                |mut cx| async move {
194                    while let Ok(message) = iopub_socket.read().await {
195                        session
196                            .update(&mut cx, |session, cx| {
197                                session.route(&message, cx);
198                            })
199                            .ok();
200                    }
201                    anyhow::Ok(())
202                }
203            })
204            .detach();
205
206            let (mut control_request_tx, mut control_request_rx) =
207                futures::channel::mpsc::channel(100);
208            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
209
210            let routing_task = cx.background_executor().spawn({
211                async move {
212                    while let Some(message) = request_rx.next().await {
213                        match message.content {
214                            JupyterMessageContent::DebugRequest(_)
215                            | JupyterMessageContent::InterruptRequest(_)
216                            | JupyterMessageContent::ShutdownRequest(_) => {
217                                control_request_tx.send(message).await?;
218                            }
219                            _ => {
220                                shell_request_tx.send(message).await?;
221                            }
222                        }
223                    }
224                    anyhow::Ok(())
225                }
226            });
227
228            let shell_task = cx.background_executor().spawn({
229                async move {
230                    while let Some(message) = shell_request_rx.next().await {
231                        shell_socket.send(message).await.ok();
232                        let reply = shell_socket.read().await?;
233                        shell_reply_tx.send(reply).await?;
234                    }
235                    anyhow::Ok(())
236                }
237            });
238
239            let control_task = cx.background_executor().spawn({
240                async move {
241                    while let Some(message) = control_request_rx.next().await {
242                        control_socket.send(message).await.ok();
243                        let reply = control_socket.read().await?;
244                        control_reply_tx.send(reply).await?;
245                    }
246                    anyhow::Ok(())
247                }
248            });
249
250            let stderr = process.stderr.take();
251
252            cx.spawn(|mut _cx| async move {
253                if stderr.is_none() {
254                    return;
255                }
256                let reader = BufReader::new(stderr.unwrap());
257                let mut lines = reader.lines();
258                while let Some(Ok(line)) = lines.next().await {
259                    log::error!("kernel: {}", line);
260                }
261            })
262            .detach();
263
264            let stdout = process.stdout.take();
265
266            cx.spawn(|mut _cx| async move {
267                if stdout.is_none() {
268                    return;
269                }
270                let reader = BufReader::new(stdout.unwrap());
271                let mut lines = reader.lines();
272                while let Some(Ok(line)) = lines.next().await {
273                    log::info!("kernel: {}", line);
274                }
275            })
276            .detach();
277
278            let status = process.status();
279
280            let process_status_task = cx.spawn(|mut cx| async move {
281                let error_message = match status.await {
282                    Ok(status) => {
283                        if status.success() {
284                            log::info!("kernel process exited successfully");
285                            return;
286                        }
287
288                        format!("kernel process exited with status: {:?}", status)
289                    }
290                    Err(err) => {
291                        format!("kernel process exited with error: {:?}", err)
292                    }
293                };
294
295                log::error!("{}", error_message);
296
297                session
298                    .update(&mut cx, |session, cx| {
299                        session.kernel_errored(error_message, cx);
300
301                        cx.notify();
302                    })
303                    .ok();
304            });
305
306            anyhow::Ok(Box::new(Self {
307                process,
308                request_tx,
309                working_directory,
310                _process_status_task: Some(process_status_task),
311                _shell_task: shell_task,
312                _control_task: control_task,
313                _routing_task: routing_task,
314                connection_path,
315                execution_state: ExecutionState::Idle,
316                kernel_info: None,
317            }) as Box<dyn RunningKernel>)
318        })
319    }
320}
321
322impl RunningKernel for NativeRunningKernel {
323    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
324        self.request_tx.clone()
325    }
326
327    fn working_directory(&self) -> &PathBuf {
328        &self.working_directory
329    }
330
331    fn execution_state(&self) -> &ExecutionState {
332        &self.execution_state
333    }
334
335    fn set_execution_state(&mut self, state: ExecutionState) {
336        self.execution_state = state;
337    }
338
339    fn kernel_info(&self) -> Option<&KernelInfoReply> {
340        self.kernel_info.as_ref()
341    }
342
343    fn set_kernel_info(&mut self, info: KernelInfoReply) {
344        self.kernel_info = Some(info);
345    }
346
347    fn force_shutdown(&mut self, _cx: &mut WindowContext) -> Task<anyhow::Result<()>> {
348        self._process_status_task.take();
349        self.request_tx.close_channel();
350
351        Task::ready(match self.process.kill() {
352            Ok(_) => Ok(()),
353            Err(error) => Err(anyhow::anyhow!(
354                "Failed to kill the kernel process: {}",
355                error
356            )),
357        })
358    }
359}
360
361impl Drop for NativeRunningKernel {
362    fn drop(&mut self) {
363        std::fs::remove_file(&self.connection_path).ok();
364        self.request_tx.close_channel();
365        self.process.kill().ok();
366    }
367}
368
369async fn read_kernelspec_at(
370    // Path should be a directory to a jupyter kernelspec, as in
371    // /usr/local/share/jupyter/kernels/python3
372    kernel_dir: PathBuf,
373    fs: &dyn Fs,
374) -> Result<LocalKernelSpecification> {
375    let path = kernel_dir;
376    let kernel_name = if let Some(kernel_name) = path.file_name() {
377        kernel_name.to_string_lossy().to_string()
378    } else {
379        anyhow::bail!("Invalid kernelspec directory: {path:?}");
380    };
381
382    if !fs.is_dir(path.as_path()).await {
383        anyhow::bail!("Not a directory: {path:?}");
384    }
385
386    let expected_kernel_json = path.join("kernel.json");
387    let spec = fs.load(expected_kernel_json.as_path()).await?;
388    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
389
390    Ok(LocalKernelSpecification {
391        name: kernel_name,
392        path,
393        kernelspec: spec,
394    })
395}
396
397/// Read a directory of kernelspec directories
398async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
399    let mut kernelspec_dirs = fs.read_dir(&path).await?;
400
401    let mut valid_kernelspecs = Vec::new();
402    while let Some(path) = kernelspec_dirs.next().await {
403        match path {
404            Ok(path) => {
405                if fs.is_dir(path.as_path()).await {
406                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
407                        valid_kernelspecs.push(kernelspec);
408                    }
409                }
410            }
411            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
412        }
413    }
414
415    Ok(valid_kernelspecs)
416}
417
418pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
419    let mut data_dirs = dirs::data_dirs();
420
421    // Pick up any kernels from conda or conda environment
422    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
423        let conda_prefix = PathBuf::from(conda_prefix);
424        let conda_data_dir = conda_prefix.join("share").join("jupyter");
425        data_dirs.push(conda_data_dir);
426    }
427
428    // Search for kernels inside the base python environment
429    let command = util::command::new_smol_command("python")
430        .arg("-c")
431        .arg("import sys; print(sys.prefix)")
432        .output()
433        .await;
434
435    if let Ok(command) = command {
436        if command.status.success() {
437            let python_prefix = String::from_utf8(command.stdout);
438            if let Ok(python_prefix) = python_prefix {
439                let python_prefix = PathBuf::from(python_prefix.trim());
440                let python_data_dir = python_prefix.join("share").join("jupyter");
441                data_dirs.push(python_data_dir);
442            }
443        }
444    }
445
446    let kernel_dirs = data_dirs
447        .iter()
448        .map(|dir| dir.join("kernels"))
449        .map(|path| read_kernels_dir(path, fs.as_ref()))
450        .collect::<Vec<_>>();
451
452    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
453    let kernel_dirs = kernel_dirs
454        .into_iter()
455        .filter_map(Result::ok)
456        .flatten()
457        .collect::<Vec<_>>();
458
459    Ok(kernel_dirs)
460}
461
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` returns one spec per valid kernelspec directory and
    /// skips stray files, directories without a `kernel.json`, and directories
    /// whose `kernel.json` is malformed.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    // A plain file is not a kernelspec directory.
                    "README.md": "not a kernelspec",
                    // A directory without kernel.json.
                    "empty": {},
                    // A directory whose kernel.json is not valid JSON.
                    "broken": {
                        "kernel.json": "{ not json"
                    }
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }
}