1use anyhow::{Context as _, Result};
  2use futures::{
  3    AsyncBufReadExt as _, SinkExt as _,
  4    channel::mpsc::{self},
  5    io::BufReader,
  6    stream::{FuturesUnordered, SelectAll, StreamExt},
  7};
  8use gpui::{App, AppContext as _, Entity, EntityId, Task, Window};
  9use jupyter_protocol::{
 10    ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent, KernelInfoReply,
 11    connection_info::{ConnectionInfo, Transport},
 12};
 13use project::Fs;
 14use runtimelib::dirs;
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    net::{IpAddr, Ipv4Addr, SocketAddr},
 20    path::PathBuf,
 21    sync::Arc,
 22};
 23use uuid::Uuid;
 24
 25use crate::Session;
 26
 27use super::RunningKernel;
 28
/// A Jupyter kernelspec found on the local machine.
///
/// Equality considers only `name` and `path` (see the `PartialEq` impl below),
/// not the parsed `kernelspec` contents.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Kernel name; `read_kernelspec_at` takes it from the kernelspec directory's
    /// final path component.
    pub name: String,
    /// Kernelspec directory (the one holding `kernel.json` when built by
    /// `read_kernelspec_at`).
    pub path: PathBuf,
    /// Parsed contents of the kernelspec's `kernel.json`.
    pub kernelspec: JupyterKernelspec,
}
 35
 36impl PartialEq for LocalKernelSpecification {
 37    fn eq(&self, other: &Self) -> bool {
 38        self.name == other.name && self.path == other.path
 39    }
 40}
 41
 42impl Eq for LocalKernelSpecification {}
 43
 44impl LocalKernelSpecification {
 45    #[must_use]
 46    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
 47        let argv = &self.kernelspec.argv;
 48
 49        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
 50        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
 51        anyhow::ensure!(
 52            argv.iter().any(|arg| arg == "{connection_file}"),
 53            "Missing 'connection_file' in argv in kernelspec {}",
 54            self.name
 55        );
 56
 57        let mut cmd = util::command::new_smol_command(&argv[0]);
 58
 59        for arg in &argv[1..] {
 60            if arg == "{connection_file}" {
 61                cmd.arg(connection_path);
 62            } else {
 63                cmd.arg(arg);
 64            }
 65        }
 66
 67        if let Some(env) = &self.kernelspec.env {
 68            cmd.envs(env);
 69        }
 70
 71        Ok(cmd)
 72    }
 73}
 74
 75// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
 76// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
 77async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
 78    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
 79    addr_zeroport.set_port(0);
 80    let mut ports: [u16; 5] = [0; 5];
 81    for i in 0..5 {
 82        let listener = TcpListener::bind(addr_zeroport).await?;
 83        let addr = listener.local_addr()?;
 84        ports[i] = addr.port();
 85    }
 86    Ok(ports)
 87}
 88
/// A Jupyter kernel running as a local child process, created by
/// `NativeRunningKernel::new`.
pub struct NativeRunningKernel {
    /// The kernel's child process (spawned with `kill_on_drop(true)`).
    pub process: smol::process::Child,
    /// Connection file written for this kernel; removed when the kernel is dropped.
    connection_path: PathBuf,
    /// Watches the process exit status and reports unsuccessful exits to the session.
    /// Taken (and thus dropped) by `force_shutdown`.
    _process_status_task: Option<Task<()>>,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Outgoing requests; `new` routes them to the control or shell channel.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Last known execution state; starts as `Idle` and is updated via
    /// `set_execution_state`.
    pub execution_state: ExecutionState,
    /// Kernel info reply, once one has been stored via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
}
 98
 99impl Debug for NativeRunningKernel {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("RunningKernel")
102            .field("process", &self.process)
103            .finish()
104    }
105}
106
107impl NativeRunningKernel {
108    pub fn new(
109        kernel_specification: LocalKernelSpecification,
110        entity_id: EntityId,
111        working_directory: PathBuf,
112        fs: Arc<dyn Fs>,
113        // todo: convert to weak view
114        session: Entity<Session>,
115        window: &mut Window,
116        cx: &mut App,
117    ) -> Task<Result<Box<dyn RunningKernel>>> {
118        window.spawn(cx, async move |cx| {
119            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
120            let ports = peek_ports(ip).await?;
121
122            let connection_info = ConnectionInfo {
123                transport: Transport::TCP,
124                ip: ip.to_string(),
125                stdin_port: ports[0],
126                control_port: ports[1],
127                hb_port: ports[2],
128                shell_port: ports[3],
129                iopub_port: ports[4],
130                signature_scheme: "hmac-sha256".to_string(),
131                key: uuid::Uuid::new_v4().to_string(),
132                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
133            };
134
135            let runtime_dir = dirs::runtime_dir();
136            fs.create_dir(&runtime_dir)
137                .await
138                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
139            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
140            let content = serde_json::to_string(&connection_info)?;
141            fs.atomic_write(connection_path.clone(), content).await?;
142
143            let mut cmd = kernel_specification.command(&connection_path)?;
144
145            let mut process = cmd
146                .current_dir(&working_directory)
147                .stdout(std::process::Stdio::piped())
148                .stderr(std::process::Stdio::piped())
149                .stdin(std::process::Stdio::piped())
150                .kill_on_drop(true)
151                .spawn()
152                .context("failed to start the kernel process")?;
153
154            let session_id = Uuid::new_v4().to_string();
155
156            let mut iopub_socket =
157                runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
158                    .await?;
159            let mut shell_socket =
160                runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
161            let mut control_socket =
162                runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
163
164            let (request_tx, mut request_rx) =
165                futures::channel::mpsc::channel::<JupyterMessage>(100);
166
167            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
168            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
169
170            let mut messages_rx = SelectAll::new();
171            messages_rx.push(control_reply_rx);
172            messages_rx.push(shell_reply_rx);
173
174            cx.spawn({
175                let session = session.clone();
176
177                async move |cx| {
178                    while let Some(message) = messages_rx.next().await {
179                        session
180                            .update_in(cx, |session, window, cx| {
181                                session.route(&message, window, cx);
182                            })
183                            .ok();
184                    }
185                }
186            })
187            .detach();
188
189            // iopub task
190            let iopub_task = cx.spawn({
191                let session = session.clone();
192
193                async move |cx| -> anyhow::Result<()> {
194                    loop {
195                        let message = iopub_socket.read().await?;
196                        session
197                            .update_in(cx, |session, window, cx| {
198                                session.route(&message, window, cx);
199                            })
200                            .ok();
201                    }
202                }
203            });
204
205            let (mut control_request_tx, mut control_request_rx) =
206                futures::channel::mpsc::channel(100);
207            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
208
209            let routing_task = cx.background_spawn({
210                async move {
211                    while let Some(message) = request_rx.next().await {
212                        match message.content {
213                            JupyterMessageContent::DebugRequest(_)
214                            | JupyterMessageContent::InterruptRequest(_)
215                            | JupyterMessageContent::ShutdownRequest(_) => {
216                                control_request_tx.send(message).await?;
217                            }
218                            _ => {
219                                shell_request_tx.send(message).await?;
220                            }
221                        }
222                    }
223                    anyhow::Ok(())
224                }
225            });
226
227            let shell_task = cx.background_spawn({
228                async move {
229                    while let Some(message) = shell_request_rx.next().await {
230                        shell_socket.send(message).await.ok();
231                        let reply = shell_socket.read().await?;
232                        shell_reply_tx.send(reply).await?;
233                    }
234                    anyhow::Ok(())
235                }
236            });
237
238            let control_task = cx.background_spawn({
239                async move {
240                    while let Some(message) = control_request_rx.next().await {
241                        control_socket.send(message).await.ok();
242                        let reply = control_socket.read().await?;
243                        control_reply_tx.send(reply).await?;
244                    }
245                    anyhow::Ok(())
246                }
247            });
248
249            let stderr = process.stderr.take();
250
251            cx.spawn(async move |_cx| {
252                if stderr.is_none() {
253                    return;
254                }
255                let reader = BufReader::new(stderr.unwrap());
256                let mut lines = reader.lines();
257                while let Some(Ok(line)) = lines.next().await {
258                    log::error!("kernel: {}", line);
259                }
260            })
261            .detach();
262
263            let stdout = process.stdout.take();
264
265            cx.spawn(async move |_cx| {
266                if stdout.is_none() {
267                    return;
268                }
269                let reader = BufReader::new(stdout.unwrap());
270                let mut lines = reader.lines();
271                while let Some(Ok(line)) = lines.next().await {
272                    log::info!("kernel: {}", line);
273                }
274            })
275            .detach();
276
277            cx.spawn({
278                let session = session.clone();
279                async move |cx| {
280                    async fn with_name(
281                        name: &'static str,
282                        task: Task<Result<()>>,
283                    ) -> (&'static str, Result<()>) {
284                        (name, task.await)
285                    }
286
287                    let mut tasks = FuturesUnordered::new();
288                    tasks.push(with_name("iopub task", iopub_task));
289                    tasks.push(with_name("shell task", shell_task));
290                    tasks.push(with_name("control task", control_task));
291                    tasks.push(with_name("routing task", routing_task));
292
293                    while let Some((name, result)) = tasks.next().await {
294                        if let Err(err) = result {
295                            log::error!("kernel: handling failed for {name}: {err:?}");
296
297                            session
298                                .update(cx, |session, cx| {
299                                    session.kernel_errored(
300                                        format!("handling failed for {name}: {err}"),
301                                        cx,
302                                    );
303                                    cx.notify();
304                                })
305                                .ok();
306                        }
307                    }
308                }
309            })
310            .detach();
311
312            let status = process.status();
313
314            let process_status_task = cx.spawn(async move |cx| {
315                let error_message = match status.await {
316                    Ok(status) => {
317                        if status.success() {
318                            log::info!("kernel process exited successfully");
319                            return;
320                        }
321
322                        format!("kernel process exited with status: {:?}", status)
323                    }
324                    Err(err) => {
325                        format!("kernel process exited with error: {:?}", err)
326                    }
327                };
328
329                log::error!("{}", error_message);
330
331                session
332                    .update(cx, |session, cx| {
333                        session.kernel_errored(error_message, cx);
334
335                        cx.notify();
336                    })
337                    .ok();
338            });
339
340            anyhow::Ok(Box::new(Self {
341                process,
342                request_tx,
343                working_directory,
344                _process_status_task: Some(process_status_task),
345                connection_path,
346                execution_state: ExecutionState::Idle,
347                kernel_info: None,
348            }) as Box<dyn RunningKernel>)
349        })
350    }
351}
352
353impl RunningKernel for NativeRunningKernel {
354    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
355        self.request_tx.clone()
356    }
357
358    fn working_directory(&self) -> &PathBuf {
359        &self.working_directory
360    }
361
362    fn execution_state(&self) -> &ExecutionState {
363        &self.execution_state
364    }
365
366    fn set_execution_state(&mut self, state: ExecutionState) {
367        self.execution_state = state;
368    }
369
370    fn kernel_info(&self) -> Option<&KernelInfoReply> {
371        self.kernel_info.as_ref()
372    }
373
374    fn set_kernel_info(&mut self, info: KernelInfoReply) {
375        self.kernel_info = Some(info);
376    }
377
378    fn force_shutdown(&mut self, _window: &mut Window, _cx: &mut App) -> Task<anyhow::Result<()>> {
379        self._process_status_task.take();
380        self.request_tx.close_channel();
381        Task::ready(self.process.kill().context("killing the kernel process"))
382    }
383}
384
385impl Drop for NativeRunningKernel {
386    fn drop(&mut self) {
387        std::fs::remove_file(&self.connection_path).ok();
388        self.request_tx.close_channel();
389        self.process.kill().ok();
390    }
391}
392
393async fn read_kernelspec_at(
394    // Path should be a directory to a jupyter kernelspec, as in
395    // /usr/local/share/jupyter/kernels/python3
396    kernel_dir: PathBuf,
397    fs: &dyn Fs,
398) -> Result<LocalKernelSpecification> {
399    let path = kernel_dir;
400    let kernel_name = if let Some(kernel_name) = path.file_name() {
401        kernel_name.to_string_lossy().into_owned()
402    } else {
403        anyhow::bail!("Invalid kernelspec directory: {path:?}");
404    };
405
406    if !fs.is_dir(path.as_path()).await {
407        anyhow::bail!("Not a directory: {path:?}");
408    }
409
410    let expected_kernel_json = path.join("kernel.json");
411    let spec = fs.load(expected_kernel_json.as_path()).await?;
412    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
413
414    Ok(LocalKernelSpecification {
415        name: kernel_name,
416        path,
417        kernelspec: spec,
418    })
419}
420
421/// Read a directory of kernelspec directories
422async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
423    let mut kernelspec_dirs = fs.read_dir(&path).await?;
424
425    let mut valid_kernelspecs = Vec::new();
426    while let Some(path) = kernelspec_dirs.next().await {
427        match path {
428            Ok(path) => {
429                if fs.is_dir(path.as_path()).await
430                    && let Ok(kernelspec) = read_kernelspec_at(path, fs).await
431                {
432                    valid_kernelspecs.push(kernelspec);
433                }
434            }
435            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
436        }
437    }
438
439    Ok(valid_kernelspecs)
440}
441
442pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
443    let mut data_dirs = dirs::data_dirs();
444
445    // Pick up any kernels from conda or conda environment
446    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
447        let conda_prefix = PathBuf::from(conda_prefix);
448        let conda_data_dir = conda_prefix.join("share").join("jupyter");
449        data_dirs.push(conda_data_dir);
450    }
451
452    // Search for kernels inside the base python environment
453    let command = util::command::new_smol_command("python")
454        .arg("-c")
455        .arg("import sys; print(sys.prefix)")
456        .output()
457        .await;
458
459    if let Ok(command) = command
460        && command.status.success()
461    {
462        let python_prefix = String::from_utf8(command.stdout);
463        if let Ok(python_prefix) = python_prefix {
464            let python_prefix = PathBuf::from(python_prefix.trim());
465            let python_data_dir = python_prefix.join("share").join("jupyter");
466            data_dirs.push(python_data_dir);
467        }
468    }
469
470    let kernel_dirs = data_dirs
471        .iter()
472        .map(|dir| dir.join("kernels"))
473        .map(|path| read_kernels_dir(path, fs.as_ref()))
474        .collect::<Vec<_>>();
475
476    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
477    let kernel_dirs = kernel_dirs
478        .into_iter()
479        .filter_map(Result::ok)
480        .flatten()
481        .collect::<Vec<_>>();
482
483    Ok(kernel_dirs)
484}
485
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` reports only subdirectories holding a parseable
    /// `kernel.json`. It skips stray files, directories without a kernelspec, and
    /// malformed kernelspecs, and none of them fail the whole scan.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "broken": {
                        "kernel.json": "{ this is not valid json"
                    },
                    "no_kernelspec": {},
                    "README.md": "not a kernelspec directory",
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
    }
}