kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Model, Task};
  9use language::LanguageName;
 10use project::{Fs, Project, WorktreeId};
 11use runtimelib::{
 12    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 13    KernelInfoReply,
 14};
 15use smol::{net::TcpListener, process::Command};
 16use std::{
 17    env,
 18    fmt::Debug,
 19    future::Future,
 20    net::{IpAddr, Ipv4Addr, SocketAddr},
 21    path::PathBuf,
 22    sync::Arc,
 23};
 24use ui::SharedString;
 25use uuid::Uuid;
 26
 27#[derive(Debug, Clone, PartialEq, Eq)]
 28pub enum KernelSpecification {
 29    Remote(RemoteKernelSpecification),
 30    Jupyter(LocalKernelSpecification),
 31    PythonEnv(LocalKernelSpecification),
 32}
 33
 34impl KernelSpecification {
 35    pub fn name(&self) -> SharedString {
 36        match self {
 37            Self::Jupyter(spec) => spec.name.clone().into(),
 38            Self::PythonEnv(spec) => spec.name.clone().into(),
 39            Self::Remote(spec) => spec.name.clone().into(),
 40        }
 41    }
 42
 43    pub fn type_name(&self) -> SharedString {
 44        match self {
 45            Self::Jupyter(_) => "Jupyter".into(),
 46            Self::PythonEnv(_) => "Python Environment".into(),
 47            Self::Remote(_) => "Remote".into(),
 48        }
 49    }
 50
 51    pub fn path(&self) -> SharedString {
 52        SharedString::from(match self {
 53            Self::Jupyter(spec) => spec.path.to_string_lossy().to_string(),
 54            Self::PythonEnv(spec) => spec.path.to_string_lossy().to_string(),
 55            Self::Remote(spec) => spec.url.to_string(),
 56        })
 57    }
 58
 59    pub fn language(&self) -> SharedString {
 60        SharedString::from(match self {
 61            Self::Jupyter(spec) => spec.kernelspec.language.clone(),
 62            Self::PythonEnv(spec) => spec.kernelspec.language.clone(),
 63            Self::Remote(spec) => spec.kernelspec.language.clone(),
 64        })
 65    }
 66}
 67
 68#[derive(Debug, Clone)]
 69pub struct LocalKernelSpecification {
 70    pub name: String,
 71    pub path: PathBuf,
 72    pub kernelspec: JupyterKernelspec,
 73}
 74
 75impl PartialEq for LocalKernelSpecification {
 76    fn eq(&self, other: &Self) -> bool {
 77        self.name == other.name && self.path == other.path
 78    }
 79}
 80
 81impl Eq for LocalKernelSpecification {}
 82
 83#[derive(Debug, Clone)]
 84pub struct RemoteKernelSpecification {
 85    pub name: String,
 86    pub url: String,
 87    pub token: String,
 88    pub kernelspec: JupyterKernelspec,
 89}
 90
 91impl PartialEq for RemoteKernelSpecification {
 92    fn eq(&self, other: &Self) -> bool {
 93        self.name == other.name && self.url == other.url
 94    }
 95}
 96
 97impl Eq for RemoteKernelSpecification {}
 98
 99impl LocalKernelSpecification {
100    #[must_use]
101    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
102        let argv = &self.kernelspec.argv;
103
104        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
105        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
106        anyhow::ensure!(
107            argv.iter().any(|arg| arg == "{connection_file}"),
108            "Missing 'connection_file' in argv in kernelspec {}",
109            self.name
110        );
111
112        let mut cmd = Command::new(&argv[0]);
113
114        for arg in &argv[1..] {
115            if arg == "{connection_file}" {
116                cmd.arg(connection_path);
117            } else {
118                cmd.arg(arg);
119            }
120        }
121
122        if let Some(env) = &self.kernelspec.env {
123            cmd.envs(env);
124        }
125
126        #[cfg(windows)]
127        {
128            use smol::process::windows::CommandExt;
129            cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
130        }
131
132        Ok(cmd)
133    }
134}
135
136// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
137// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
138async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
139    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
140    addr_zeroport.set_port(0);
141    let mut ports: [u16; 5] = [0; 5];
142    for i in 0..5 {
143        let listener = TcpListener::bind(addr_zeroport).await?;
144        let addr = listener.local_addr()?;
145        ports[i] = addr.port();
146    }
147    Ok(ports)
148}
149
/// User-facing summary of a kernel's lifecycle state.
#[derive(Debug, Clone)]
pub enum KernelStatus {
    Idle,
    Busy,
    Starting,
    Error,
    ShuttingDown,
    Shutdown,
    Restarting,
}

impl KernelStatus {
    /// Whether the kernel is up and accepting requests (idle or busy).
    pub fn is_connected(&self) -> bool {
        matches!(self, KernelStatus::Idle | KernelStatus::Busy)
    }
}

// `Display` rather than a direct `ToString` impl: `to_string()` is still
// provided by the standard blanket impl, and the status can now also be used
// directly in `format!("{}")` / `write!`.
impl std::fmt::Display for KernelStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            KernelStatus::Idle => "Idle",
            KernelStatus::Busy => "Busy",
            KernelStatus::Starting => "Starting",
            KernelStatus::Error => "Error",
            KernelStatus::ShuttingDown => "Shutting Down",
            KernelStatus::Shutdown => "Shutdown",
            KernelStatus::Restarting => "Restarting",
        };
        f.write_str(label)
    }
}
183
184impl From<&Kernel> for KernelStatus {
185    fn from(kernel: &Kernel) -> Self {
186        match kernel {
187            Kernel::RunningKernel(kernel) => match kernel.execution_state {
188                ExecutionState::Idle => KernelStatus::Idle,
189                ExecutionState::Busy => KernelStatus::Busy,
190            },
191            Kernel::StartingKernel(_) => KernelStatus::Starting,
192            Kernel::ErroredLaunch(_) => KernelStatus::Error,
193            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
194            Kernel::Shutdown => KernelStatus::Shutdown,
195            Kernel::Restarting => KernelStatus::Restarting,
196        }
197    }
198}
199
200#[derive(Debug)]
201pub enum Kernel {
202    RunningKernel(RunningKernel),
203    StartingKernel(Shared<Task<()>>),
204    ErroredLaunch(String),
205    ShuttingDown,
206    Shutdown,
207    Restarting,
208}
209
210impl Kernel {
211    pub fn status(&self) -> KernelStatus {
212        self.into()
213    }
214
215    pub fn set_execution_state(&mut self, status: &ExecutionState) {
216        if let Kernel::RunningKernel(running_kernel) = self {
217            running_kernel.execution_state = status.clone();
218        }
219    }
220
221    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
222        if let Kernel::RunningKernel(running_kernel) = self {
223            running_kernel.kernel_info = Some(kernel_info.clone());
224        }
225    }
226
227    pub fn is_shutting_down(&self) -> bool {
228        match self {
229            Kernel::Restarting | Kernel::ShuttingDown => true,
230            Kernel::RunningKernel(_)
231            | Kernel::StartingKernel(_)
232            | Kernel::ErroredLaunch(_)
233            | Kernel::Shutdown => false,
234        }
235    }
236}
237
238pub struct RunningKernel {
239    pub process: smol::process::Child,
240    _shell_task: Task<Result<()>>,
241    _iopub_task: Task<Result<()>>,
242    _control_task: Task<Result<()>>,
243    _routing_task: Task<Result<()>>,
244    connection_path: PathBuf,
245    pub working_directory: PathBuf,
246    pub request_tx: mpsc::Sender<JupyterMessage>,
247    pub execution_state: ExecutionState,
248    pub kernel_info: Option<KernelInfoReply>,
249}
250
251type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
252
253impl Debug for RunningKernel {
254    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255        f.debug_struct("RunningKernel")
256            .field("process", &self.process)
257            .finish()
258    }
259}
260
261impl RunningKernel {
262    pub fn new(
263        kernel_specification: KernelSpecification,
264        entity_id: EntityId,
265        working_directory: PathBuf,
266        fs: Arc<dyn Fs>,
267        cx: &mut AppContext,
268    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
269        let kernel_specification = match kernel_specification {
270            KernelSpecification::Jupyter(spec) => spec,
271            KernelSpecification::PythonEnv(spec) => spec,
272            KernelSpecification::Remote(_spec) => {
273                // todo!(): Implement remote kernel specification
274                return Task::ready(Err(anyhow::anyhow!(
275                    "Running remote kernels is not supported"
276                )));
277            }
278        };
279
280        cx.spawn(|cx| async move {
281            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
282            let ports = peek_ports(ip).await?;
283
284            let connection_info = ConnectionInfo {
285                transport: "tcp".to_string(),
286                ip: ip.to_string(),
287                stdin_port: ports[0],
288                control_port: ports[1],
289                hb_port: ports[2],
290                shell_port: ports[3],
291                iopub_port: ports[4],
292                signature_scheme: "hmac-sha256".to_string(),
293                key: uuid::Uuid::new_v4().to_string(),
294                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
295            };
296
297            let runtime_dir = dirs::runtime_dir();
298            fs.create_dir(&runtime_dir)
299                .await
300                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
301            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
302            let content = serde_json::to_string(&connection_info)?;
303            fs.atomic_write(connection_path.clone(), content).await?;
304
305            let mut cmd = kernel_specification.command(&connection_path)?;
306
307            let process = cmd
308                .current_dir(&working_directory)
309                .stdout(std::process::Stdio::piped())
310                .stderr(std::process::Stdio::piped())
311                .stdin(std::process::Stdio::piped())
312                .kill_on_drop(true)
313                .spawn()
314                .context("failed to start the kernel process")?;
315
316            let session_id = Uuid::new_v4().to_string();
317
318            let mut iopub_socket = connection_info
319                .create_client_iopub_connection("", &session_id)
320                .await?;
321            let mut shell_socket = connection_info
322                .create_client_shell_connection(&session_id)
323                .await?;
324            let mut control_socket = connection_info
325                .create_client_control_connection(&session_id)
326                .await?;
327
328            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
329
330            let (request_tx, mut request_rx) =
331                futures::channel::mpsc::channel::<JupyterMessage>(100);
332
333            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
334            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
335
336            let mut messages_rx = SelectAll::new();
337            messages_rx.push(iosub);
338            messages_rx.push(control_reply_rx);
339            messages_rx.push(shell_reply_rx);
340
341            let iopub_task = cx.background_executor().spawn({
342                async move {
343                    while let Ok(message) = iopub_socket.read().await {
344                        iopub.send(message).await?;
345                    }
346                    anyhow::Ok(())
347                }
348            });
349
350            let (mut control_request_tx, mut control_request_rx) =
351                futures::channel::mpsc::channel(100);
352            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
353
354            let routing_task = cx.background_executor().spawn({
355                async move {
356                    while let Some(message) = request_rx.next().await {
357                        match message.content {
358                            JupyterMessageContent::DebugRequest(_)
359                            | JupyterMessageContent::InterruptRequest(_)
360                            | JupyterMessageContent::ShutdownRequest(_) => {
361                                control_request_tx.send(message).await?;
362                            }
363                            _ => {
364                                shell_request_tx.send(message).await?;
365                            }
366                        }
367                    }
368                    anyhow::Ok(())
369                }
370            });
371
372            let shell_task = cx.background_executor().spawn({
373                async move {
374                    while let Some(message) = shell_request_rx.next().await {
375                        shell_socket.send(message).await.ok();
376                        let reply = shell_socket.read().await?;
377                        shell_reply_tx.send(reply).await?;
378                    }
379                    anyhow::Ok(())
380                }
381            });
382
383            let control_task = cx.background_executor().spawn({
384                async move {
385                    while let Some(message) = control_request_rx.next().await {
386                        control_socket.send(message).await.ok();
387                        let reply = control_socket.read().await?;
388                        control_reply_tx.send(reply).await?;
389                    }
390                    anyhow::Ok(())
391                }
392            });
393
394            anyhow::Ok((
395                Self {
396                    process,
397                    request_tx,
398                    working_directory,
399                    _shell_task: shell_task,
400                    _iopub_task: iopub_task,
401                    _control_task: control_task,
402                    _routing_task: routing_task,
403                    connection_path,
404                    execution_state: ExecutionState::Idle,
405                    kernel_info: None,
406                },
407                messages_rx,
408            ))
409        })
410    }
411}
412
413impl Drop for RunningKernel {
414    fn drop(&mut self) {
415        std::fs::remove_file(&self.connection_path).ok();
416        self.request_tx.close_channel();
417        self.process.kill().ok();
418    }
419}
420
421async fn read_kernelspec_at(
422    // Path should be a directory to a jupyter kernelspec, as in
423    // /usr/local/share/jupyter/kernels/python3
424    kernel_dir: PathBuf,
425    fs: &dyn Fs,
426) -> Result<LocalKernelSpecification> {
427    let path = kernel_dir;
428    let kernel_name = if let Some(kernel_name) = path.file_name() {
429        kernel_name.to_string_lossy().to_string()
430    } else {
431        anyhow::bail!("Invalid kernelspec directory: {path:?}");
432    };
433
434    if !fs.is_dir(path.as_path()).await {
435        anyhow::bail!("Not a directory: {path:?}");
436    }
437
438    let expected_kernel_json = path.join("kernel.json");
439    let spec = fs.load(expected_kernel_json.as_path()).await?;
440    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
441
442    Ok(LocalKernelSpecification {
443        name: kernel_name,
444        path,
445        kernelspec: spec,
446    })
447}
448
449/// Read a directory of kernelspec directories
450async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
451    let mut kernelspec_dirs = fs.read_dir(&path).await?;
452
453    let mut valid_kernelspecs = Vec::new();
454    while let Some(path) = kernelspec_dirs.next().await {
455        match path {
456            Ok(path) => {
457                if fs.is_dir(path.as_path()).await {
458                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
459                        valid_kernelspecs.push(kernelspec);
460                    }
461                }
462            }
463            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
464        }
465    }
466
467    Ok(valid_kernelspecs)
468}
469
470pub fn python_env_kernel_specifications(
471    project: &Model<Project>,
472    worktree_id: WorktreeId,
473    cx: &mut AppContext,
474) -> impl Future<Output = Result<Vec<KernelSpecification>>> {
475    let python_language = LanguageName::new("Python");
476    let toolchains = project
477        .read(cx)
478        .available_toolchains(worktree_id, python_language, cx);
479    let background_executor = cx.background_executor().clone();
480
481    async move {
482        let toolchains = if let Some(toolchains) = toolchains.await {
483            toolchains
484        } else {
485            return Ok(Vec::new());
486        };
487
488        let kernelspecs = toolchains.toolchains.into_iter().map(|toolchain| {
489            background_executor.spawn(async move {
490                let python_path = toolchain.path.to_string();
491
492                // Check if ipykernel is installed
493                let ipykernel_check = Command::new(&python_path)
494                    .args(&["-c", "import ipykernel"])
495                    .output()
496                    .await;
497
498                if ipykernel_check.is_ok() && ipykernel_check.unwrap().status.success() {
499                    // Create a default kernelspec for this environment
500                    let default_kernelspec = JupyterKernelspec {
501                        argv: vec![
502                            python_path.clone(),
503                            "-m".to_string(),
504                            "ipykernel_launcher".to_string(),
505                            "-f".to_string(),
506                            "{connection_file}".to_string(),
507                        ],
508                        display_name: toolchain.name.to_string(),
509                        language: "python".to_string(),
510                        interrupt_mode: None,
511                        metadata: None,
512                        env: None,
513                    };
514
515                    Some(KernelSpecification::PythonEnv(LocalKernelSpecification {
516                        name: toolchain.name.to_string(),
517                        path: PathBuf::from(&python_path),
518                        kernelspec: default_kernelspec,
519                    }))
520                } else {
521                    None
522                }
523            })
524        });
525
526        let kernel_specs = futures::future::join_all(kernelspecs)
527            .await
528            .into_iter()
529            .flatten()
530            .collect();
531
532        anyhow::Ok(kernel_specs)
533    }
534}
535
536pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
537    let mut data_dirs = dirs::data_dirs();
538
539    // Pick up any kernels from conda or conda environment
540    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
541        let conda_prefix = PathBuf::from(conda_prefix);
542        let conda_data_dir = conda_prefix.join("share").join("jupyter");
543        data_dirs.push(conda_data_dir);
544    }
545
546    // Search for kernels inside the base python environment
547    let mut command = Command::new("python");
548    command.arg("-c");
549    command.arg("import sys; print(sys.prefix)");
550
551    #[cfg(windows)]
552    {
553        use smol::process::windows::CommandExt;
554        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
555    }
556
557    let command = command.output().await;
558
559    if let Ok(command) = command {
560        if command.status.success() {
561            let python_prefix = String::from_utf8(command.stdout);
562            if let Ok(python_prefix) = python_prefix {
563                let python_prefix = PathBuf::from(python_prefix.trim());
564                let python_data_dir = python_prefix.join("share").join("jupyter");
565                data_dirs.push(python_data_dir);
566            }
567        }
568    }
569
570    let kernel_dirs = data_dirs
571        .iter()
572        .map(|dir| dir.join("kernels"))
573        .map(|path| read_kernels_dir(path, fs.as_ref()))
574        .collect::<Vec<_>>();
575
576    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
577    let kernel_dirs = kernel_dirs
578        .into_iter()
579        .filter_map(Result::ok)
580        .flatten()
581        .collect::<Vec<_>>();
582
583    Ok(kernel_dirs)
584}
585
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` returns every valid kernelspec directory and silently
    /// skips entries that are not kernelspecs: plain files, directories without
    /// a `kernel.json`, and directories whose `kernel.json` does not parse.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                        "label": "cargo check",
                        "command": "cargo",
                        "args": ["check", "--all"]
                    },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                            "display_name": "Python 3",
                            "language": "python",
                            "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                            "display_name": "Deno",
                            "language": "typescript",
                            "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                            "env": {}
                        }"#
                    },
                    "no-kernel-json": {
                        "README.md": "not a kernelspec"
                    },
                    "broken": {
                        "kernel.json": "{ not valid json"
                    },
                    "stray-file.txt": "not a directory"
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        kernels.sort_by(|a, b| a.name.cmp(&b.name));

        assert_eq!(
            kernels.iter().map(|c| c.name.clone()).collect::<Vec<_>>(),
            vec!["deno", "python"]
        );
        assert_eq!(
            kernels
                .iter()
                .map(|c| c.kernelspec.language.clone())
                .collect::<Vec<_>>(),
            vec!["typescript", "python"]
        );
    }
}