1use anyhow::{Context as _, Result};
2use futures::{
3 channel::mpsc::{self},
4 stream::{SelectAll, StreamExt},
5 SinkExt as _,
6};
7use gpui::{AppContext, EntityId, Task};
8use jupyter_protocol::{JupyterMessage, JupyterMessageContent, KernelInfoReply};
9use project::Fs;
10use runtimelib::{dirs, ConnectionInfo, ExecutionState, JupyterKernelspec};
11use smol::{net::TcpListener, process::Command};
12use std::{
13 env,
14 fmt::Debug,
15 net::{IpAddr, Ipv4Addr, SocketAddr},
16 path::PathBuf,
17 sync::Arc,
18};
19use uuid::Uuid;
20
21use super::{JupyterMessageChannel, RunningKernel};
22
/// A Jupyter kernelspec discovered on the local filesystem.
#[derive(Debug, Clone)]
pub struct LocalKernelSpecification {
    /// Kernel name, taken from the final path component of the kernelspec directory.
    pub name: String,
    /// The kernelspec directory (the one that contains `kernel.json`).
    pub path: PathBuf,
    /// Parsed contents of the directory's `kernel.json`.
    pub kernelspec: JupyterKernelspec,
}
29
30impl PartialEq for LocalKernelSpecification {
31 fn eq(&self, other: &Self) -> bool {
32 self.name == other.name && self.path == other.path
33 }
34}
35
36impl Eq for LocalKernelSpecification {}
37
38impl LocalKernelSpecification {
39 #[must_use]
40 fn command(&self, connection_path: &PathBuf) -> Result<Command> {
41 let argv = &self.kernelspec.argv;
42
43 anyhow::ensure!(!argv.is_empty(), "Empty argv in kernelspec {}", self.name);
44 anyhow::ensure!(argv.len() >= 2, "Invalid argv in kernelspec {}", self.name);
45 anyhow::ensure!(
46 argv.iter().any(|arg| arg == "{connection_file}"),
47 "Missing 'connection_file' in argv in kernelspec {}",
48 self.name
49 );
50
51 let mut cmd = Command::new(&argv[0]);
52
53 for arg in &argv[1..] {
54 if arg == "{connection_file}" {
55 cmd.arg(connection_path);
56 } else {
57 cmd.arg(arg);
58 }
59 }
60
61 if let Some(env) = &self.kernelspec.env {
62 cmd.envs(env);
63 }
64
65 #[cfg(windows)]
66 {
67 use smol::process::windows::CommandExt;
68 cmd.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
69 }
70
71 Ok(cmd)
72 }
73}
74
75// Find a set of open ports. This creates a listener with port set to 0. The listener will be closed at the end when it goes out of scope.
76// There's a race condition between closing the ports and usage by a kernel, but it's inherent to the Jupyter protocol.
77async fn peek_ports(ip: IpAddr) -> Result<[u16; 5]> {
78 let mut addr_zeroport: SocketAddr = SocketAddr::new(ip, 0);
79 addr_zeroport.set_port(0);
80 let mut ports: [u16; 5] = [0; 5];
81 for i in 0..5 {
82 let listener = TcpListener::bind(addr_zeroport).await?;
83 let addr = listener.local_addr()?;
84 ports[i] = addr.port();
85 }
86 Ok(ports)
87}
88
/// A Jupyter kernel running as a local child process, together with the background
/// tasks that move messages between the kernel's sockets and in-memory channels.
pub struct NativeRunningKernel {
    /// The kernel process. It is spawned with `kill_on_drop(true)` and is also killed
    /// explicitly in `Drop`.
    pub process: smol::process::Child,
    // The four task handles below are stored but never read. NOTE(review): presumably
    // they are held so the background work lives exactly as long as the kernel
    // (dropping a gpui `Task` is expected to cancel it) — confirm.
    _shell_task: Task<Result<()>>,
    _iopub_task: Task<Result<()>>,
    _control_task: Task<Result<()>>,
    _routing_task: Task<Result<()>>,
    /// Connection file written when the kernel is started; deleted in `Drop`.
    connection_path: PathBuf,
    /// Directory the kernel process was started in.
    pub working_directory: PathBuf,
    /// Sender for outgoing requests. Debug, interrupt and shutdown requests are routed
    /// to the control socket; every other request goes to the shell socket.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Execution state as last set via `set_execution_state`; initialised to `Idle`.
    pub execution_state: ExecutionState,
    /// Kernel info reply, once one has been stored via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
}
101
102impl Debug for NativeRunningKernel {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("RunningKernel")
105 .field("process", &self.process)
106 .finish()
107 }
108}
109
110impl NativeRunningKernel {
111 pub fn new(
112 kernel_specification: LocalKernelSpecification,
113 entity_id: EntityId,
114 working_directory: PathBuf,
115 fs: Arc<dyn Fs>,
116 cx: &mut AppContext,
117 ) -> Task<Result<(Self, JupyterMessageChannel)>> {
118 cx.spawn(|cx| async move {
119 let ip = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
120 let ports = peek_ports(ip).await?;
121
122 let connection_info = ConnectionInfo {
123 transport: "tcp".to_string(),
124 ip: ip.to_string(),
125 stdin_port: ports[0],
126 control_port: ports[1],
127 hb_port: ports[2],
128 shell_port: ports[3],
129 iopub_port: ports[4],
130 signature_scheme: "hmac-sha256".to_string(),
131 key: uuid::Uuid::new_v4().to_string(),
132 kernel_name: Some(format!("zed-{}", kernel_specification.name)),
133 };
134
135 let runtime_dir = dirs::runtime_dir();
136 fs.create_dir(&runtime_dir)
137 .await
138 .with_context(|| format!("Failed to create jupyter runtime dir {runtime_dir:?}"))?;
139 let connection_path = runtime_dir.join(format!("kernel-zed-{entity_id}.json"));
140 let content = serde_json::to_string(&connection_info)?;
141 fs.atomic_write(connection_path.clone(), content).await?;
142
143 let mut cmd = kernel_specification.command(&connection_path)?;
144
145 let process = cmd
146 .current_dir(&working_directory)
147 .stdout(std::process::Stdio::piped())
148 .stderr(std::process::Stdio::piped())
149 .stdin(std::process::Stdio::piped())
150 .kill_on_drop(true)
151 .spawn()
152 .context("failed to start the kernel process")?;
153
154 let session_id = Uuid::new_v4().to_string();
155
156 let mut iopub_socket =
157 runtimelib::create_client_iopub_connection(&connection_info, "", &session_id)
158 .await?;
159 let mut shell_socket =
160 runtimelib::create_client_shell_connection(&connection_info, &session_id).await?;
161 let mut control_socket =
162 runtimelib::create_client_control_connection(&connection_info, &session_id).await?;
163
164 let (mut iopub, iosub) = futures::channel::mpsc::channel(100);
165
166 let (request_tx, mut request_rx) =
167 futures::channel::mpsc::channel::<JupyterMessage>(100);
168
169 let (mut control_reply_tx, control_reply_rx) = futures::channel::mpsc::channel(100);
170 let (mut shell_reply_tx, shell_reply_rx) = futures::channel::mpsc::channel(100);
171
172 let mut messages_rx = SelectAll::new();
173 messages_rx.push(iosub);
174 messages_rx.push(control_reply_rx);
175 messages_rx.push(shell_reply_rx);
176
177 let iopub_task = cx.background_executor().spawn({
178 async move {
179 while let Ok(message) = iopub_socket.read().await {
180 iopub.send(message).await?;
181 }
182 anyhow::Ok(())
183 }
184 });
185
186 let (mut control_request_tx, mut control_request_rx) =
187 futures::channel::mpsc::channel(100);
188 let (mut shell_request_tx, mut shell_request_rx) = futures::channel::mpsc::channel(100);
189
190 let routing_task = cx.background_executor().spawn({
191 async move {
192 while let Some(message) = request_rx.next().await {
193 match message.content {
194 JupyterMessageContent::DebugRequest(_)
195 | JupyterMessageContent::InterruptRequest(_)
196 | JupyterMessageContent::ShutdownRequest(_) => {
197 control_request_tx.send(message).await?;
198 }
199 _ => {
200 shell_request_tx.send(message).await?;
201 }
202 }
203 }
204 anyhow::Ok(())
205 }
206 });
207
208 let shell_task = cx.background_executor().spawn({
209 async move {
210 while let Some(message) = shell_request_rx.next().await {
211 shell_socket.send(message).await.ok();
212 let reply = shell_socket.read().await?;
213 shell_reply_tx.send(reply).await?;
214 }
215 anyhow::Ok(())
216 }
217 });
218
219 let control_task = cx.background_executor().spawn({
220 async move {
221 while let Some(message) = control_request_rx.next().await {
222 control_socket.send(message).await.ok();
223 let reply = control_socket.read().await?;
224 control_reply_tx.send(reply).await?;
225 }
226 anyhow::Ok(())
227 }
228 });
229
230 anyhow::Ok((
231 Self {
232 process,
233 request_tx,
234 working_directory,
235 _shell_task: shell_task,
236 _iopub_task: iopub_task,
237 _control_task: control_task,
238 _routing_task: routing_task,
239 connection_path,
240 execution_state: ExecutionState::Idle,
241 kernel_info: None,
242 },
243 messages_rx,
244 ))
245 })
246 }
247}
248
249impl RunningKernel for NativeRunningKernel {
250 fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
251 self.request_tx.clone()
252 }
253
254 fn working_directory(&self) -> &PathBuf {
255 &self.working_directory
256 }
257
258 fn execution_state(&self) -> &ExecutionState {
259 &self.execution_state
260 }
261
262 fn set_execution_state(&mut self, state: ExecutionState) {
263 self.execution_state = state;
264 }
265
266 fn kernel_info(&self) -> Option<&KernelInfoReply> {
267 self.kernel_info.as_ref()
268 }
269
270 fn set_kernel_info(&mut self, info: KernelInfoReply) {
271 self.kernel_info = Some(info);
272 }
273
274 fn force_shutdown(&mut self) -> anyhow::Result<()> {
275 match self.process.kill() {
276 Ok(_) => Ok(()),
277 Err(error) => Err(anyhow::anyhow!(
278 "Failed to kill the kernel process: {}",
279 error
280 )),
281 }
282 }
283}
284
285impl Drop for NativeRunningKernel {
286 fn drop(&mut self) {
287 std::fs::remove_file(&self.connection_path).ok();
288 self.request_tx.close_channel();
289 self.process.kill().ok();
290 }
291}
292
293async fn read_kernelspec_at(
294 // Path should be a directory to a jupyter kernelspec, as in
295 // /usr/local/share/jupyter/kernels/python3
296 kernel_dir: PathBuf,
297 fs: &dyn Fs,
298) -> Result<LocalKernelSpecification> {
299 let path = kernel_dir;
300 let kernel_name = if let Some(kernel_name) = path.file_name() {
301 kernel_name.to_string_lossy().to_string()
302 } else {
303 anyhow::bail!("Invalid kernelspec directory: {path:?}");
304 };
305
306 if !fs.is_dir(path.as_path()).await {
307 anyhow::bail!("Not a directory: {path:?}");
308 }
309
310 let expected_kernel_json = path.join("kernel.json");
311 let spec = fs.load(expected_kernel_json.as_path()).await?;
312 let spec = serde_json::from_str::<JupyterKernelspec>(&spec)?;
313
314 Ok(LocalKernelSpecification {
315 name: kernel_name,
316 path,
317 kernelspec: spec,
318 })
319}
320
321/// Read a directory of kernelspec directories
322async fn read_kernels_dir(path: PathBuf, fs: &dyn Fs) -> Result<Vec<LocalKernelSpecification>> {
323 let mut kernelspec_dirs = fs.read_dir(&path).await?;
324
325 let mut valid_kernelspecs = Vec::new();
326 while let Some(path) = kernelspec_dirs.next().await {
327 match path {
328 Ok(path) => {
329 if fs.is_dir(path.as_path()).await {
330 if let Ok(kernelspec) = read_kernelspec_at(path, fs).await {
331 valid_kernelspecs.push(kernelspec);
332 }
333 }
334 }
335 Err(err) => log::warn!("Error reading kernelspec directory: {err:?}"),
336 }
337 }
338
339 Ok(valid_kernelspecs)
340}
341
342pub async fn local_kernel_specifications(fs: Arc<dyn Fs>) -> Result<Vec<LocalKernelSpecification>> {
343 let mut data_dirs = dirs::data_dirs();
344
345 // Pick up any kernels from conda or conda environment
346 if let Ok(conda_prefix) = env::var("CONDA_PREFIX") {
347 let conda_prefix = PathBuf::from(conda_prefix);
348 let conda_data_dir = conda_prefix.join("share").join("jupyter");
349 data_dirs.push(conda_data_dir);
350 }
351
352 // Search for kernels inside the base python environment
353 let mut command = Command::new("python");
354 command.arg("-c");
355 command.arg("import sys; print(sys.prefix)");
356
357 #[cfg(windows)]
358 {
359 use smol::process::windows::CommandExt;
360 command.creation_flags(windows::Win32::System::Threading::CREATE_NO_WINDOW.0);
361 }
362
363 let command = command.output().await;
364
365 if let Ok(command) = command {
366 if command.status.success() {
367 let python_prefix = String::from_utf8(command.stdout);
368 if let Ok(python_prefix) = python_prefix {
369 let python_prefix = PathBuf::from(python_prefix.trim());
370 let python_data_dir = python_prefix.join("share").join("jupyter");
371 data_dirs.push(python_data_dir);
372 }
373 }
374 }
375
376 let kernel_dirs = data_dirs
377 .iter()
378 .map(|dir| dir.join("kernels"))
379 .map(|path| read_kernels_dir(path, fs.as_ref()))
380 .collect::<Vec<_>>();
381
382 let kernel_dirs = futures::future::join_all(kernel_dirs).await;
383 let kernel_dirs = kernel_dirs
384 .into_iter()
385 .filter_map(Result::ok)
386 .flatten()
387 .collect::<Vec<_>>();
388
389 Ok(kernel_dirs)
390}
391
#[cfg(test)]
mod test {
    use super::*;
    use std::path::PathBuf;

    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;

    /// `read_kernels_dir` finds every kernelspec directory under `/jupyter/kernels`
    /// and names each one after its directory.
    #[gpui::test]
    async fn test_get_kernelspecs(cx: &mut TestAppContext) {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/jupyter",
            json!({
                ".zed": {
                    "settings.json": r#"{ "tab_size": 8 }"#,
                    "tasks.json": r#"[{
                    "label": "cargo check",
                    "command": "cargo",
                    "args": ["check", "--all"]
                },]"#,
                },
                "kernels": {
                    "python": {
                        "kernel.json": r#"{
                        "display_name": "Python 3",
                        "language": "python",
                        "argv": ["python3", "-m", "ipykernel_launcher", "-f", "{connection_file}"],
                        "env": {}
                    }"#
                    },
                    "deno": {
                        "kernel.json": r#"{
                        "display_name": "Deno",
                        "language": "typescript",
                        "argv": ["deno", "run", "--unstable", "--allow-net", "--allow-read", "https://deno.land/std/http/file_server.ts", "{connection_file}"],
                        "env": {}
                    }"#
                    }
                },
            }),
        )
        .await;

        let mut kernels = read_kernels_dir(PathBuf::from("/jupyter/kernels"), fs.as_ref())
            .await
            .unwrap();
        kernels.sort_by(|left, right| left.name.cmp(&right.name));

        let names: Vec<&str> = kernels.iter().map(|kernel| kernel.name.as_str()).collect();
        assert_eq!(names, ["deno", "python"]);
    }
}