native_kernel.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self},
  4    stream::{SelectAll, StreamExt},
  5    SinkExt as _,
  6};
  7use gpui::{AppContext, EntityId, Task};
  8use jupyter_protocol::{JupyterMessage, JupyterMessageContent, KernelInfoReply};
  9use project::Fs;
 10use runtimelib::{dirs, ConnectionInfo, ExecutionState, JupyterKernelspec};
 11use smol::{net::TcpListener, process::Command};
 12use std::{
 13    env,
 14    fmt::Debug,
 15    net::{IpAddr, Ipv4Addr, SocketAddr},
 16    path::PathBuf,
 17    sync::Arc,
 18};
 19use uuid::Uuid;
 20
 21use super::{JupyterMessageChannel, RunningKernel};
 22
 23#[derive(Debug, Clone)]
 24pub struct LocalKernelSpecification {
 25    pub name: String,
 26    pub path: PathBuf,
 27    pub kernelspec: JupyterKernelspec,
 28}
 29
 30impl PartialEq for LocalKernelSpecification {
 31    fn eq(&self, other: &Self) -> bool {
 32        self.name == other.name && self.path == other.path
 33    }
 34}
 35
 36impl Eq for LocalKernelSpecification {}
 37
 38impl LocalKernelSpecification {
 39    #[must_use]
 40    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 41        let argv = &self.kernelspec.argv;
 42
 43        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 44        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 45        anyhow::ensure!(
 46            argv.iter().any(|arg| arg == "{connection_file}"),
 47            "Missing 'connection_file' in argv in kernelspec {}",
 48            self.name
 49        );
 50
 51        let mut cmd = Command::new(&argv[0]);
 52
 53        for arg in &argv[1..] {
 54            if arg == "{connection_file}" {
 55                cmd.arg(connection_path);
 56            } else {
 57                cmd.arg(arg);
 58            }
 59        }
 60
 61        if let Some(env) = &self.kernelspec.env {
 62            cmd.envs(env);
 63        }
 64
 65        #[cfg(windows)]
 66        {
 67            use smol::process::windows::CommandExt;
 68            cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
 69        }
 70
 71        Ok(cmd)
 72    }
 73}
 74
 75// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 76// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 77async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 78    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 79    addr_zeroport.set_port(0);
 80    let mut ports: [u16; 5] = [0; 5];
 81    for i in 0..5 {
 82        let listener = TcpListener::bind(addr_zeroport).await?;
 83        let addr = listener.local_addr()?;
 84        ports[i] = addr.port();
 85    }
 86    Ok(ports)
 87}
 88
 89pub struct NativeRunningKernel {
 90    pub process: smol::process::Child,
 91    _shell_task: Task<Result<()>>,
 92    _iopub_task: Task<Result<()>>,
 93    _control_task: Task<Result<()>>,
 94    _routing_task: Task<Result<()>>,
 95    connection_path: PathBuf,
 96    pub working_directory: PathBuf,
 97    pub request_tx: mpsc::Sender<JupyterMessage>,
 98    pub execution_state: ExecutionState,
 99    pub kernel_info: Option<KernelInfoReply>,
100}
101
102impl Debug for NativeRunningKernel {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("RunningKernel")
105            .field("process", &self.process)
106            .finish()
107    }
108}
109
110impl NativeRunningKernel {
111    pub fn new(
112        kernel_specification: LocalKernelSpecification,
113        entity_id: EntityId,
114        working_directory: PathBuf,
115        fs: Arc<dyn Fs>,
116        cx: &mut AppContext,
117    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
118        cx.spawn(|cx| async move {
119            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
120            let ports = peek_ports(ip).await?;
121
122            let connection_info = ConnectionInfo {
123                transport: "tcp".to_string(),
124                ip: ip.to_string(),
125                stdin_port: ports[0],
126                control_port: ports[1],
127                hb_port: ports[2],
128                shell_port: ports[3],
129                iopub_port: ports[4],
130                signature_scheme: "hmac-sha256".to_string(),
131                key: uuid::Uuid::new_v4().to_string(),
132                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
133            };
134
135            let runtime_dir = dirs::runtime_dir();
136            fs.create_dir(&runtime_dir)
137                .await
138                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
139            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
140            let content = serde_json::to_string(&connection_info)?;
141            fs.atomic_write(connection_path.clone(), content).await?;
142
143            let mut cmd = kernel_specification.command(&connection_path)?;
144
145            let process = cmd
146                .current_dir(&working_directory)
147                .stdout(std::process::Stdio::piped())
148                .stderr(std::process::Stdio::piped())
149                .stdin(std::process::Stdio::piped())
150                .kill_on_drop(true)
151                .spawn()
152                .context("failed to start the kernel process")?;
153
154            let session_id = Uuid::new_v4().to_string();
155
156            let mut iopub_socket =
157                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
158                    .await?;
159            let mut shell_socket =
160                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
161            let mut control_socket =
162                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
163
164            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
165
166            let (request_tx, mut request_rx) =
167                futures::channel::mpsc::channel::<JupyterMessage>(100);
168
169            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
170            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
171
172            let mut messages_rx = SelectAll::new();
173            messages_rx.push(iosub);
174            messages_rx.push(control_reply_rx);
175            messages_rx.push(shell_reply_rx);
176
177            let iopub_task = cx.background_executor().spawn({
178                async move {
179                    while let Ok(message) = iopub_socket.read().await {
180                        iopub.send(message).await?;
181                    }
182                    anyhow::Ok(())
183                }
184            });
185
186            let (mut control_request_tx, mut control_request_rx) =
187                futures::channel::mpsc::channel(100);
188            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
189
190            let routing_task = cx.background_executor().spawn({
191                async move {
192                    while let Some(message) = request_rx.next().await {
193                        match message.content {
194                            JupyterMessageContent::DebugRequest(_)
195                            | JupyterMessageContent::InterruptRequest(_)
196                            | JupyterMessageContent::ShutdownRequest(_) => {
197                                control_request_tx.send(message).await?;
198                            }
199                            _ => {
200                                shell_request_tx.send(message).await?;
201                            }
202                        }
203                    }
204                    anyhow::Ok(())
205                }
206            });
207
208            let shell_task = cx.background_executor().spawn({
209                async move {
210                    while let Some(message) = shell_request_rx.next().await {
211                        shell_socket.send(message).await.ok();
212                        let reply = shell_socket.read().await?;
213                        shell_reply_tx.send(reply).await?;
214                    }
215                    anyhow::Ok(())
216                }
217            });
218
219            let control_task = cx.background_executor().spawn({
220                async move {
221                    while let Some(message) = control_request_rx.next().await {
222                        control_socket.send(message).await.ok();
223                        let reply = control_socket.read().await?;
224                        control_reply_tx.send(reply).await?;
225                    }
226                    anyhow::Ok(())
227                }
228            });
229
230            anyhow::Ok((
231                Self {
232                    process,
233                    request_tx,
234                    working_directory,
235                    _shell_task: shell_task,
236                    _iopub_task: iopub_task,
237                    _control_task: control_task,
238                    _routing_task: routing_task,
239                    connection_path,
240                    execution_state: ExecutionState::Idle,
241                    kernel_info: None,
242                },
243                messages_rx,
244            ))
245        })
246    }
247}
248
249impl RunningKernel for NativeRunningKernel {
250    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
251        self.request_tx.clone()
252    }
253
254    fn working_directory(&self) -> &PathBuf {
255        &self.working_directory
256    }
257
258    fn execution_state(&self) -> &ExecutionState {
259        &self.execution_state
260    }
261
262    fn set_execution_state(&mut self, state: ExecutionState) {
263        self.execution_state = state;
264    }
265
266    fn kernel_info(&self) -> Option<&KernelInfoReply> {
267        self.kernel_info.as_ref()
268    }
269
270    fn set_kernel_info(&mut self, info: KernelInfoReply) {
271        self.kernel_info = Some(info);
272    }
273
274    fn force_shutdown(&mut self) -> anyhow::Result<()> {
275        match self.process.kill() {
276            Ok(_) => Ok(()),
277            Err(error) => Err(anyhow::anyhow!(
278                "Failed to kill the kernel process: {}",
279                error
280            )),
281        }
282    }
283}
284
285impl Drop for NativeRunningKernel {
286    fn drop(&mut self) {
287        std::fs::remove_file(&self.connection_path).ok();
288        self.request_tx.close_channel();
289        self.process.kill().ok();
290    }
291}
292
293async fn read_kernelspec_at(
294    // Path should be a directory to a jupyter kernelspec, as in
295    // /usr/local/share/jupyter/kernels/python3
296    kernel_dir: PathBuf,
297    fs: &dyn Fs,
298) -> Result<LocalKernelSpecification> {
299    let path = kernel_dir;
300    let kernel_name = if let Some(kernel_name) = path.file_name() {
301        kernel_name.to_string_lossy().to_string()
302    } else {
303        anyhow::bail!("Invalid kernelspec directory: {path:?}");
304    };
305
306    if !fs.is_dir(path.as_path()).await {
307        anyhow::bail!("Not a directory: {path:?}");
308    }
309
310    let expected_kernel_json = path.join("kernel.json");
311    let spec = fs.load(expected_kernel_json.as_path()).await?;
312    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
313
314    Ok(LocalKernelSpecification {
315        name: kernel_name,
316        path,
317        kernelspec: spec,
318    })
319}
320
321/// Read a directory of kernelspec directories
322async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
323    let mut kernelspec_dirs = fs.read_dir(&path).await?;
324
325    let mut valid_kernelspecs = Vec::new();
326    while let Some(path) = kernelspec_dirs.next().await {
327        match path {
328            Ok(path) => {
329                if fs.is_dir(path.as_path()).await {
330                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
331                        valid_kernelspecs.push(kernelspec);
332                    }
333                }
334            }
335            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
336        }
337    }
338
339    Ok(valid_kernelspecs)
340}
341
342pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
343    let mut data_dirs = dirs::data_dirs();
344
345    // Pick up any kernels from conda or conda environment
346    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
347        let conda_prefix = PathBuf::from(conda_prefix);
348        let conda_data_dir = conda_prefix.join("share").join("jupyter");
349        data_dirs.push(conda_data_dir);
350    }
351
352    // Search for kernels inside the base python environment
353    let mut command = Command::new("python");
354    command.arg("-c");
355    command.arg("import sys; print(sys.prefix)");
356
357    #[cfg(windows)]
358    {
359        use smol::process::windows::CommandExt;
360        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
361    }
362
363    let command = command.output().await;
364
365    if let Ok(command) = command {
366        if command.status.success() {
367            let python_prefix = String::from_utf8(command.stdout);
368            if let Ok(python_prefix) = python_prefix {
369                let python_prefix = PathBuf::from(python_prefix.trim());
370                let python_data_dir = python_prefix.join("share").join("jupyter");
371                data_dirs.push(python_data_dir);
372            }
373        }
374    }
375
376    let kernel_dirs = data_dirs
377        .iter()
378        .map(|dir| dir.join("kernels"))
379        .map(|path| read_kernels_dir(path, fs.as_ref()))
380        .collect::<Vec<_>>();
381
382    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
383    let kernel_dirs = kernel_dirs
384        .into_iter()
385        .filter_map(Result::ok)
386        .flatten()
387        .collect::<Vec<_>>();
388
389    Ok(kernel_dirs)
390}
391
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` should find every kernelspec directory below
    /// `/jupyter/kernels` in the fake filesystem.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    }
                },
            }),
        )
        .await;

        let kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        // Directory listing order is not part of the contract, so compare the
        // names in sorted order.
        let mut names: Vec<String> = kernels.into_iter().map(|spec| spec.name).collect();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}