native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    AsyncBufReadExt as _, SinkExt as _,
  4    channel::mpsc::{self},
  5    io::BufReader,
  6    stream::{SelectAll, StreamExt},
  7};
  8use gpui::{App, AppContext as _, Entity, EntityId, Task, Window};
  9use jupyter_protocol::{
 10    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
 11    connection_info::{ConnectionInfo, Transport},
 12};
 13use project::Fs;
 14use runtimelib::dirs;
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    net::{IpAddr, Ipv4Addr, SocketAddr},
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use uuid::Uuid;
 24
 25use crate::Session;
 26
 27use super::RunningKernel;
 28
 29#[derive(Debug, Clone)]
 30pub struct LocalKernelSpecification {
 31    pub name: String,
 32    pub path: PathBuf,
 33    pub kernelspec: JupyterKernelspec,
 34}
 35
 36impl PartialEq for LocalKernelSpecification {
 37    fn eq(&self, other: &Self) -> bool {
 38        self.name == other.name && self.path == other.path
 39    }
 40}
 41
 42impl Eq for LocalKernelSpecification {}
 43
 44impl LocalKernelSpecification {
 45    #[must_use]
 46    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 47        let argv = &self.kernelspec.argv;
 48
 49        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 50        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 51        anyhow::ensure!(
 52            argv.iter().any(|arg| arg == "{connection_file}"),
 53            "Missing 'connection_file' in argv in kernelspec {}",
 54            self.name
 55        );
 56
 57        let mut cmd = util::command::new_smol_command(&argv[0]);
 58
 59        for arg in &argv[1..] {
 60            if arg == "{connection_file}" {
 61                cmd.arg(connection_path);
 62            } else {
 63                cmd.arg(arg);
 64            }
 65        }
 66
 67        if let Some(env) = &self.kernelspec.env {
 68            cmd.envs(env);
 69        }
 70
 71        Ok(cmd)
 72    }
 73}
 74
 75// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 76// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 77async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 78    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 79    addr_zeroport.set_port(0);
 80    let mut ports: [u16; 5] = [0; 5];
 81    for i in 0..5 {
 82        let listener = TcpListener::bind(addr_zeroport).await?;
 83        let addr = listener.local_addr()?;
 84        ports[i] = addr.port();
 85    }
 86    Ok(ports)
 87}
 88
 89pub struct NativeRunningKernel {
 90    pub process: smol::process::Child,
 91    _shell_task: Task<Result<()>>,
 92    _control_task: Task<Result<()>>,
 93    _routing_task: Task<Result<()>>,
 94    connection_path: PathBuf,
 95    _process_status_task: Option<Task<()>>,
 96    pub working_directory: PathBuf,
 97    pub request_tx: mpsc::Sender<JupyterMessage>,
 98    pub execution_state: ExecutionState,
 99    pub kernel_info: Option<KernelInfoReply>,
100}
101
102impl Debug for NativeRunningKernel {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("RunningKernel")
105            .field("process", &self.process)
106            .finish()
107    }
108}
109
110impl NativeRunningKernel {
111    pub fn new(
112        kernel_specification: LocalKernelSpecification,
113        entity_id: EntityId,
114        working_directory: PathBuf,
115        fs: Arc<dyn Fs>,
116        // todo: convert to weak view
117        session: Entity<Session>,
118        window: &mut Window,
119        cx: &mut App,
120    ) -> Task<Result<Box<dyn RunningKernel>>> {
121        window.spawn(cx, async move |cx| {
122            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
123            let ports = peek_ports(ip).await?;
124
125            let connection_info = ConnectionInfo {
126                transport: Transport::TCP,
127                ip: ip.to_string(),
128                stdin_port: ports[0],
129                control_port: ports[1],
130                hb_port: ports[2],
131                shell_port: ports[3],
132                iopub_port: ports[4],
133                signature_scheme: "hmac-sha256".to_string(),
134                key: uuid::Uuid::new_v4().to_string(),
135                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
136            };
137
138            let runtime_dir = dirs::runtime_dir();
139            fs.create_dir(&runtime_dir)
140                .await
141                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
142            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
143            let content = serde_json::to_string(&connection_info)?;
144            fs.atomic_write(connection_path.clone(), content).await?;
145
146            let mut cmd = kernel_specification.command(&connection_path)?;
147
148            let mut process = cmd
149                .current_dir(&working_directory)
150                .stdout(std::process::Stdio::piped())
151                .stderr(std::process::Stdio::piped())
152                .stdin(std::process::Stdio::piped())
153                .kill_on_drop(true)
154                .spawn()
155                .context("failed to start the kernel process")?;
156
157            let session_id = Uuid::new_v4().to_string();
158
159            let mut iopub_socket =
160                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
161                    .await?;
162            let mut shell_socket =
163                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
164            let mut control_socket =
165                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
166
167            let (request_tx, mut request_rx) =
168                futures::channel::mpsc::channel::<JupyterMessage>(100);
169
170            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
171            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
172
173            let mut messages_rx = SelectAll::new();
174            messages_rx.push(control_reply_rx);
175            messages_rx.push(shell_reply_rx);
176
177            cx.spawn({
178                let session = session.clone();
179
180                async move |cx| {
181                    while let Some(message) = messages_rx.next().await {
182                        session
183                            .update_in(cx, |session, window, cx| {
184                                session.route(&message, window, cx);
185                            })
186                            .ok();
187                    }
188                    anyhow::Ok(())
189                }
190            })
191            .detach();
192
193            // iopub task
194            cx.spawn({
195                let session = session.clone();
196
197                async move |cx| {
198                    while let Ok(message) = iopub_socket.read().await {
199                        session
200                            .update_in(cx, |session, window, cx| {
201                                session.route(&message, window, cx);
202                            })
203                            .ok();
204                    }
205                    anyhow::Ok(())
206                }
207            })
208            .detach();
209
210            let (mut control_request_tx, mut control_request_rx) =
211                futures::channel::mpsc::channel(100);
212            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
213
214            let routing_task = cx.background_spawn({
215                async move {
216                    while let Some(message) = request_rx.next().await {
217                        match message.content {
218                            JupyterMessageContent::DebugRequest(_)
219                            | JupyterMessageContent::InterruptRequest(_)
220                            | JupyterMessageContent::ShutdownRequest(_) => {
221                                control_request_tx.send(message).await?;
222                            }
223                            _ => {
224                                shell_request_tx.send(message).await?;
225                            }
226                        }
227                    }
228                    anyhow::Ok(())
229                }
230            });
231
232            let shell_task = cx.background_spawn({
233                async move {
234                    while let Some(message) = shell_request_rx.next().await {
235                        shell_socket.send(message).await.ok();
236                        let reply = shell_socket.read().await?;
237                        shell_reply_tx.send(reply).await?;
238                    }
239                    anyhow::Ok(())
240                }
241            });
242
243            let control_task = cx.background_spawn({
244                async move {
245                    while let Some(message) = control_request_rx.next().await {
246                        control_socket.send(message).await.ok();
247                        let reply = control_socket.read().await?;
248                        control_reply_tx.send(reply).await?;
249                    }
250                    anyhow::Ok(())
251                }
252            });
253
254            let stderr = process.stderr.take();
255
256            cx.spawn(async move |_cx| {
257                if stderr.is_none() {
258                    return;
259                }
260                let reader = BufReader::new(stderr.unwrap());
261                let mut lines = reader.lines();
262                while let Some(Ok(line)) = lines.next().await {
263                    log::error!("kernel: {}", line);
264                }
265            })
266            .detach();
267
268            let stdout = process.stdout.take();
269
270            cx.spawn(async move |_cx| {
271                if stdout.is_none() {
272                    return;
273                }
274                let reader = BufReader::new(stdout.unwrap());
275                let mut lines = reader.lines();
276                while let Some(Ok(line)) = lines.next().await {
277                    log::info!("kernel: {}", line);
278                }
279            })
280            .detach();
281
282            let status = process.status();
283
284            let process_status_task = cx.spawn(async move |cx| {
285                let error_message = match status.await {
286                    Ok(status) => {
287                        if status.success() {
288                            log::info!("kernel process exited successfully");
289                            return;
290                        }
291
292                        format!("kernel process exited with status: {:?}", status)
293                    }
294                    Err(err) => {
295                        format!("kernel process exited with error: {:?}", err)
296                    }
297                };
298
299                log::error!("{}", error_message);
300
301                session
302                    .update(cx, |session, cx| {
303                        session.kernel_errored(error_message, cx);
304
305                        cx.notify();
306                    })
307                    .ok();
308            });
309
310            anyhow::Ok(Box::new(Self {
311                process,
312                request_tx,
313                working_directory,
314                _process_status_task: Some(process_status_task),
315                _shell_task: shell_task,
316                _control_task: control_task,
317                _routing_task: routing_task,
318                connection_path,
319                execution_state: ExecutionState::Idle,
320                kernel_info: None,
321            }) as Box<dyn RunningKernel>)
322        })
323    }
324}
325
326impl RunningKernel for NativeRunningKernel {
327    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
328        self.request_tx.clone()
329    }
330
331    fn working_directory(&self) -> &PathBuf {
332        &self.working_directory
333    }
334
335    fn execution_state(&self) -> &ExecutionState {
336        &self.execution_state
337    }
338
339    fn set_execution_state(&mut self, state: ExecutionState) {
340        self.execution_state = state;
341    }
342
343    fn kernel_info(&self) -> Option<&KernelInfoReply> {
344        self.kernel_info.as_ref()
345    }
346
347    fn set_kernel_info(&mut self, info: KernelInfoReply) {
348        self.kernel_info = Some(info);
349    }
350
351    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
352        self._process_status_task.take();
353        self.request_tx.close_channel();
354
355        Task::ready(match self.process.kill() {
356            Ok(_) => Ok(()),
357            Err(error) => Err(anyhow::anyhow!(
358                "Failed to kill the kernel process: {}",
359                error
360            )),
361        })
362    }
363}
364
365impl Drop for NativeRunningKernel {
366    fn drop(&mut self) {
367        std::fs::remove_file(&self.connection_path).ok();
368        self.request_tx.close_channel();
369        self.process.kill().ok();
370    }
371}
372
373async fn read_kernelspec_at(
374    // Path should be a directory to a jupyter kernelspec, as in
375    // /usr/local/share/jupyter/kernels/python3
376    kernel_dir: PathBuf,
377    fs: &dyn Fs,
378) -> Result<LocalKernelSpecification> {
379    let path = kernel_dir;
380    let kernel_name = if let Some(kernel_name) = path.file_name() {
381        kernel_name.to_string_lossy().to_string()
382    } else {
383        anyhow::bail!("Invalid kernelspec directory: {path:?}");
384    };
385
386    if !fs.is_dir(path.as_path()).await {
387        anyhow::bail!("Not a directory: {path:?}");
388    }
389
390    let expected_kernel_json = path.join("kernel.json");
391    let spec = fs.load(expected_kernel_json.as_path()).await?;
392    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
393
394    Ok(LocalKernelSpecification {
395        name: kernel_name,
396        path,
397        kernelspec: spec,
398    })
399}
400
401/// Read a directory of kernelspec directories
402async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
403    let mut kernelspec_dirs = fs.read_dir(&path).await?;
404
405    let mut valid_kernelspecs = Vec::new();
406    while let Some(path) = kernelspec_dirs.next().await {
407        match path {
408            Ok(path) => {
409                if fs.is_dir(path.as_path()).await {
410                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
411                        valid_kernelspecs.push(kernelspec);
412                    }
413                }
414            }
415            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
416        }
417    }
418
419    Ok(valid_kernelspecs)
420}
421
422pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
423    let mut data_dirs = dirs::data_dirs();
424
425    // Pick up any kernels from conda or conda environment
426    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
427        let conda_prefix = PathBuf::from(conda_prefix);
428        let conda_data_dir = conda_prefix.join("share").join("jupyter");
429        data_dirs.push(conda_data_dir);
430    }
431
432    // Search for kernels inside the base python environment
433    let command = util::command::new_smol_command("python")
434        .arg("-c")
435        .arg("import sys; print(sys.prefix)")
436        .output()
437        .await;
438
439    if let Ok(command) = command {
440        if command.status.success() {
441            let python_prefix = String::from_utf8(command.stdout);
442            if let Ok(python_prefix) = python_prefix {
443                let python_prefix = PathBuf::from(python_prefix.trim());
444                let python_data_dir = python_prefix.join("share").join("jupyter");
445                data_dirs.push(python_data_dir);
446            }
447        }
448    }
449
450    let kernel_dirs = data_dirs
451        .iter()
452        .map(|dir| dir.join("kernels"))
453        .map(|path| read_kernels_dir(path, fs.as_ref()))
454        .collect::<Vec<_>>();
455
456    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
457    let kernel_dirs = kernel_dirs
458        .into_iter()
459        .filter_map(Result::ok)
460        .flatten()
461        .collect::<Vec<_>>();
462
463    Ok(kernel_dirs)
464}
465
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    // None of the following are valid kernelspecs and must be skipped.
                    "README.md": "not a kernelspec",
                    "empty": {},
                    "broken": {
                        "kernel.json": "{ not valid json"
                    }
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }

    /// Builds a specification whose `kernel.json` has the given `argv`.
    fn kernelspec_with_argv(argv: &[&str]) -> LocalKernelSpecification {
        let kernelspec = serde_json::from_value::<JupyterKernelspec>(json!({
            "display_name": "Test",
            "language": "python",
            "argv": argv,
            "env": {},
        }))
        .unwrap();
        LocalKernelSpecification {
            name: "test".to_string(),
            path: PathBuf::from("/jupyter/kernels/test"),
            kernelspec,
        }
    }

    #[test]
    fn test_command_validates_argv() {
        let connection_path = PathBuf::from("/tmp/kernel-zed-test.json");

        assert!(kernelspec_with_argv(&[]).command(&connection_path).is_err());
        assert!(
            kernelspec_with_argv(&["python3"])
                .command(&connection_path)
                .is_err()
        );
        assert!(
            kernelspec_with_argv(&["python3", "-m", "ipykernel_launcher"])
                .command(&connection_path)
                .is_err()
        );
        assert!(
            kernelspec_with_argv(&["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"])
                .command(&connection_path)
                .is_ok()
        );
    }
}