kernels.rs

  1use anyhow::{Context as _, Result};
  2use futures::{
  3    channel::mpsc::{self, Receiver},
  4    future::Shared,
  5    stream::{self, SelectAll, StreamExt},
  6    SinkExt as _,
  7};
  8use gpui::{AppContext, EntityId, Task};
  9use project::Fs;
 10use runtimelib::{
 11    dirs, ConnectionInfo, ExecutionState, JupyterKernelspec, JupyterMessage, JupyterMessageContent,
 12    KernelInfoReply,
 13};
 14use smol::{net::TcpListener, process::Command};
 15use std::{
 16    env,
 17    fmt::Debug,
 18    net::{IpAddr, Ipv4Addr, SocketAddr},
 19    path::PathBuf,
 20    sync::Arc,
 21};
 22use ui::SharedString;
 23use uuid::Uuid;
 24
 25#[derive(Debug, Clone, PartialEq, Eq)]
 26pub enum KernelSpecification {
 27    Remote(RemoteKernelSpecification),
 28    Jupyter(LocalKernelSpecification),
 29    PythonEnv(LocalKernelSpecification),
 30}
 31
 32impl KernelSpecification {
 33    pub fn name(&self) -> SharedString {
 34        match self {
 35            Self::Jupyter(spec) => spec.name.clone().into(),
 36            Self::PythonEnv(spec) => spec.name.clone().into(),
 37            Self::Remote(spec) => spec.name.clone().into(),
 38        }
 39    }
 40
 41    pub fn type_name(&self) -> SharedString {
 42        match self {
 43            Self::Jupyter(_) => "Jupyter".into(),
 44            Self::PythonEnv(_) => "Python Environment".into(),
 45            Self::Remote(_) => "Remote".into(),
 46        }
 47    }
 48
 49    pub fn path(&self) -> SharedString {
 50        SharedString::from(match self {
 51            Self::Jupyter(spec) => spec.path.to_string_lossy().to_string(),
 52            Self::PythonEnv(spec) => spec.path.to_string_lossy().to_string(),
 53            Self::Remote(spec) => spec.url.to_string(),
 54        })
 55    }
 56
 57    pub fn language(&self) -> SharedString {
 58        SharedString::from(match self {
 59            Self::Jupyter(spec) => spec.kernelspec.language.clone(),
 60            Self::PythonEnv(spec) => spec.kernelspec.language.clone(),
 61            Self::Remote(spec) => spec.kernelspec.language.clone(),
 62        })
 63    }
 64}
 65
 66#[derive(Debug, Clone)]
 67pub struct LocalKernelSpecification {
 68    pub name: String,
 69    pub path: PathBuf,
 70    pub kernelspec: JupyterKernelspec,
 71}
 72
 73impl PartialEq for LocalKernelSpecification {
 74    fn eq(&self, other: &Self) -> bool {
 75        self.name == other.name && self.path == other.path
 76    }
 77}
 78
 79impl Eq for LocalKernelSpecification {}
 80
 81#[derive(Debug, Clone)]
 82pub struct RemoteKernelSpecification {
 83    pub name: String,
 84    pub url: String,
 85    pub token: String,
 86    pub kernelspec: JupyterKernelspec,
 87}
 88
 89impl PartialEq for RemoteKernelSpecification {
 90    fn eq(&self, other: &Self) -> bool {
 91        self.name == other.name && self.url == other.url
 92    }
 93}
 94
 95impl Eq for RemoteKernelSpecification {}
 96
 97impl LocalKernelSpecification {
 98    #[must_use]
 99    fn command(&self, connection_path: &PathBuf) -> Result<Command> {
100        let argv = &self.kernelspec.argv;
101
102        anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
103        anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
104        anyhow::ensure!(
105            argv.iter().any(|arg| arg == "{connection_file}"),
106            "Missing 'connection_file' in argv in kernelspec {}",
107            self.name
108        );
109
110        let mut cmd = Command::new(&argv[0]);
111
112        for arg in &argv[1..] {
113            if arg == "{connection_file}" {
114                cmd.arg(connection_path);
115            } else {
116                cmd.arg(arg);
117            }
118        }
119
120        if let Some(env) = &self.kernelspec.env {
121            cmd.envs(env);
122        }
123
124        #[cfg(windows)]
125        {
126            use smol::process::windows::CommandExt;
127            cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
128        }
129
130        Ok(cmd)
131    }
132}
133
134// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
135// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
136async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
137    let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
138    addr_zeroport.set_port(0);
139    let mut ports: [u16; 5] = [0; 5];
140    for i in 0..5 {
141        let listener = TcpListener::bind(addr_zeroport).await?;
142        let addr = listener.local_addr()?;
143        ports[i] = addr.port();
144    }
145    Ok(ports)
146}
147
/// Coarse status of a kernel, with a human-readable label available via `Display`.
#[derive(Debug, Clone)]
pub enum KernelStatus {
    Idle,
    Busy,
    Starting,
    Error,
    ShuttingDown,
    Shutdown,
    Restarting,
}

impl KernelStatus {
    /// Returns `true` when the kernel is running, i.e. the status is `Idle` or `Busy`.
    pub fn is_connected(&self) -> bool {
        matches!(self, KernelStatus::Idle | KernelStatus::Busy)
    }
}

/// Human-readable label for the status.
///
/// Implementing `Display` (rather than `ToString` directly) keeps `to_string()`
/// available through the standard blanket impl. It also lets the status be used
/// with `format!`, including width and alignment flags.
impl std::fmt::Display for KernelStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            KernelStatus::Idle => "Idle",
            KernelStatus::Busy => "Busy",
            KernelStatus::Starting => "Starting",
            KernelStatus::Error => "Error",
            KernelStatus::ShuttingDown => "Shutting Down",
            KernelStatus::Shutdown => "Shutdown",
            KernelStatus::Restarting => "Restarting",
        };
        f.pad(label)
    }
}
181
182impl From<&Kernel> for KernelStatus {
183    fn from(kernel: &Kernel) -> Self {
184        match kernel {
185            Kernel::RunningKernel(kernel) => match kernel.execution_state {
186                ExecutionState::Idle => KernelStatus::Idle,
187                ExecutionState::Busy => KernelStatus::Busy,
188            },
189            Kernel::StartingKernel(_) => KernelStatus::Starting,
190            Kernel::ErroredLaunch(_) => KernelStatus::Error,
191            Kernel::ShuttingDown => KernelStatus::ShuttingDown,
192            Kernel::Shutdown => KernelStatus::Shutdown,
193            Kernel::Restarting => KernelStatus::Restarting,
194        }
195    }
196}
197
198#[derive(Debug)]
199pub enum Kernel {
200    RunningKernel(RunningKernel),
201    StartingKernel(Shared<Task<()>>),
202    ErroredLaunch(String),
203    ShuttingDown,
204    Shutdown,
205    Restarting,
206}
207
208impl Kernel {
209    pub fn status(&self) -> KernelStatus {
210        self.into()
211    }
212
213    pub fn set_execution_state(&mut self, status: &ExecutionState) {
214        if let Kernel::RunningKernel(running_kernel) = self {
215            running_kernel.execution_state = status.clone();
216        }
217    }
218
219    pub fn set_kernel_info(&mut self, kernel_info: &KernelInfoReply) {
220        if let Kernel::RunningKernel(running_kernel) = self {
221            running_kernel.kernel_info = Some(kernel_info.clone());
222        }
223    }
224
225    pub fn is_shutting_down(&self) -> bool {
226        match self {
227            Kernel::Restarting | Kernel::ShuttingDown => true,
228            Kernel::RunningKernel(_)
229            | Kernel::StartingKernel(_)
230            | Kernel::ErroredLaunch(_)
231            | Kernel::Shutdown => false,
232        }
233    }
234}
235
236pub struct RunningKernel {
237    pub process: smol::process::Child,
238    _shell_task: Task<Result<()>>,
239    _iopub_task: Task<Result<()>>,
240    _control_task: Task<Result<()>>,
241    _routing_task: Task<Result<()>>,
242    connection_path: PathBuf,
243    pub working_directory: PathBuf,
244    pub request_tx: mpsc::Sender<JupyterMessage>,
245    pub execution_state: ExecutionState,
246    pub kernel_info: Option<KernelInfoReply>,
247}
248
249type JupyterMessageChannel = stream::SelectAll<Receiver<JupyterMessage>>;
250
251impl Debug for RunningKernel {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("RunningKernel")
254            .field("process", &self.process)
255            .finish()
256    }
257}
258
259impl RunningKernel {
260    pub fn new(
261        kernel_specification: KernelSpecification,
262        entity_id: EntityId,
263        working_directory: PathBuf,
264        fs: Arc<dyn Fs>,
265        cx: &mut AppContext,
266    ) -> Task<Result<(Self, JupyterMessageChannel)>> {
267        let kernel_specification = match kernel_specification {
268            KernelSpecification::Jupyter(spec) => spec,
269            KernelSpecification::PythonEnv(spec) => spec,
270            KernelSpecification::Remote(_spec) => {
271                // todo!(): Implement remote kernel specification
272                return Task::ready(Err(anyhow::anyhow!(
273                    "Running remote kernels is not supported"
274                )));
275            }
276        };
277
278        cx.spawn(|cx| async move {
279            let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
280            let ports = peek_ports(ip).await?;
281
282            let connection_info = ConnectionInfo {
283                transport: "tcp".to_string(),
284                ip: ip.to_string(),
285                stdin_port: ports[0],
286                control_port: ports[1],
287                hb_port: ports[2],
288                shell_port: ports[3],
289                iopub_port: ports[4],
290                signature_scheme: "hmac-sha256".to_string(),
291                key: uuid::Uuid::new_v4().to_string(),
292                kernel_name: Some(format!("zed-{}", kernel_specification.name)),
293            };
294
295            let runtime_dir = dirs::runtime_dir();
296            fs.create_dir(&runtime_dir)
297                .await
298                .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
299            let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
300            let content = serde_json::to_string(&connection_info)?;
301            fs.atomic_write(connection_path.clone(), content).await?;
302
303            let mut cmd = kernel_specification.command(&connection_path)?;
304
305            let process = cmd
306                .current_dir(&working_directory)
307                .stdout(std::process::Stdio::piped())
308                .stderr(std::process::Stdio::piped())
309                .stdin(std::process::Stdio::piped())
310                .kill_on_drop(true)
311                .spawn()
312                .context("failed to start the kernel process")?;
313
314            let session_id = Uuid::new_v4().to_string();
315
316            let mut iopub_socket = connection_info
317                .create_client_iopub_connection("", &session_id)
318                .await?;
319            let mut shell_socket = connection_info
320                .create_client_shell_connection(&session_id)
321                .await?;
322            let mut control_socket = connection_info
323                .create_client_control_connection(&session_id)
324                .await?;
325
326            let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
327
328            let (request_tx, mut request_rx) =
329                futures::channel::mpsc::channel::<JupyterMessage>(100);
330
331            let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
332            let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
333
334            let mut messages_rx = SelectAll::new();
335            messages_rx.push(iosub);
336            messages_rx.push(control_reply_rx);
337            messages_rx.push(shell_reply_rx);
338
339            let iopub_task = cx.background_executor().spawn({
340                async move {
341                    while let Ok(message) = iopub_socket.read().await {
342                        iopub.send(message).await?;
343                    }
344                    anyhow::Ok(())
345                }
346            });
347
348            let (mut control_request_tx, mut control_request_rx) =
349                futures::channel::mpsc::channel(100);
350            let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
351
352            let routing_task = cx.background_executor().spawn({
353                async move {
354                    while let Some(message) = request_rx.next().await {
355                        match message.content {
356                            JupyterMessageContent::DebugRequest(_)
357                            | JupyterMessageContent::InterruptRequest(_)
358                            | JupyterMessageContent::ShutdownRequest(_) => {
359                                control_request_tx.send(message).await?;
360                            }
361                            _ => {
362                                shell_request_tx.send(message).await?;
363                            }
364                        }
365                    }
366                    anyhow::Ok(())
367                }
368            });
369
370            let shell_task = cx.background_executor().spawn({
371                async move {
372                    while let Some(message) = shell_request_rx.next().await {
373                        shell_socket.send(message).await.ok();
374                        let reply = shell_socket.read().await?;
375                        shell_reply_tx.send(reply).await?;
376                    }
377                    anyhow::Ok(())
378                }
379            });
380
381            let control_task = cx.background_executor().spawn({
382                async move {
383                    while let Some(message) = control_request_rx.next().await {
384                        control_socket.send(message).await.ok();
385                        let reply = control_socket.read().await?;
386                        control_reply_tx.send(reply).await?;
387                    }
388                    anyhow::Ok(())
389                }
390            });
391
392            anyhow::Ok((
393                Self {
394                    process,
395                    request_tx,
396                    working_directory,
397                    _shell_task: shell_task,
398                    _iopub_task: iopub_task,
399                    _control_task: control_task,
400                    _routing_task: routing_task,
401                    connection_path,
402                    execution_state: ExecutionState::Idle,
403                    kernel_info: None,
404                },
405                messages_rx,
406            ))
407        })
408    }
409}
410
411impl Drop for RunningKernel {
412    fn drop(&mut self) {
413        std::fs::remove_file(&self.connection_path).ok();
414        self.request_tx.close_channel();
415        self.process.kill().ok();
416    }
417}
418
419async fn read_kernelspec_at(
420    // Path should be a directory to a jupyter kernelspec, as in
421    // /usr/local/share/jupyter/kernels/python3
422    kernel_dir: PathBuf,
423    fs: &dyn Fs,
424) -> Result<LocalKernelSpecification> {
425    let path = kernel_dir;
426    let kernel_name = if let Some(kernel_name) = path.file_name() {
427        kernel_name.to_string_lossy().to_string()
428    } else {
429        anyhow::bail!("Invalid kernelspec directory: {path:?}");
430    };
431
432    if !fs.is_dir(path.as_path()).await {
433        anyhow::bail!("Not a directory: {path:?}");
434    }
435
436    let expected_kernel_json = path.join("kernel.json");
437    let spec = fs.load(expected_kernel_json.as_path()).await?;
438    let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
439
440    Ok(LocalKernelSpecification {
441        name: kernel_name,
442        path,
443        kernelspec: spec,
444    })
445}
446
447/// Read a directory of kernelspec directories
448async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
449    let mut kernelspec_dirs = fs.read_dir(&path).await?;
450
451    let mut valid_kernelspecs = Vec::new();
452    while let Some(path) = kernelspec_dirs.next().await {
453        match path {
454            Ok(path) => {
455                if fs.is_dir(path.as_path()).await {
456                    if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
457                        valid_kernelspecs.push(kernelspec);
458                    }
459                }
460            }
461            Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
462        }
463    }
464
465    Ok(valid_kernelspecs)
466}
467
468pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
469    let mut data_dirs = dirs::data_dirs();
470
471    // Pick up any kernels from conda or conda environment
472    if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
473        let conda_prefix = PathBuf::from(conda_prefix);
474        let conda_data_dir = conda_prefix.join("share").join("jupyter");
475        data_dirs.push(conda_data_dir);
476    }
477
478    // Search for kernels inside the base python environment
479    let mut command = Command::new("python");
480    command.arg("-c");
481    command.arg("import sys; print(sys.prefix)");
482
483    #[cfg(windows)]
484    {
485        use smol::process::windows::CommandExt;
486        command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
487    }
488
489    let command = command.output().await;
490
491    if let Ok(command) = command {
492        if command.status.success() {
493            let python_prefix = String::from_utf8(command.stdout);
494            if let Ok(python_prefix) = python_prefix {
495                let python_prefix = PathBuf::from(python_prefix.trim());
496                let python_data_dir = python_prefix.join("share").join("jupyter");
497                data_dirs.push(python_data_dir);
498            }
499        }
500    }
501
502    let kernel_dirs = data_dirs
503        .iter()
504        .map(|dir| dir.join("kernels"))
505        .map(|path| read_kernels_dir(path, fs.as_ref()))
506        .collect::<Vec<_>>();
507
508    let kernel_dirs = futures::future::join_all(kernel_dirs).await;
509    let kernel_dirs = kernel_dirs
510        .into_iter()
511        .filter_map(Result::ok)
512        .flatten()
513        .collect::<Vec<_>>();
514
515    Ok(kernel_dirs)
516}
517
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` returns one spec per kernelspec subdirectory, named after
    /// that subdirectory; unrelated siblings of `kernels` are not visited.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        let tree = json!({
            ".zed": {
                "settings.json": r#"{ "tab_size": 8 }"#,
                "tasks.json": r#"[{
                    "label": "cargo check",
                    "command": "cargo",
                    "args": ["check", "--all"]
                },]"#,
            },
            "kernels": {
                "python": {
                    "kernel.json": r#"{
                        "display_name": "Python 3",
                        "language": "python",
                        "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                        "env": {}
                    }"#
                },
                "deno": {
                    "kernel.json": r#"{
                        "display_name": "Deno",
                        "language": "typescript",
                        "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                        "env": {}
                    }"#
                }
            },
        });
        fs.insert_tree("/jupyter", tree).await;

        let kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();

        let mut names: Vec<String> = kernels.into_iter().map(|kernel| kernel.name).collect();
        names.sort();

        assert_eq!(names, vec!["deno", "python"]);
    }
}