ssh_kernel.rs

  1use super::{KernelSession, RunningKernel, SshRemoteKernelSpecification, start_kernel_tasks};
  2use anyhow::{Context as _, Result};
  3use client::proto;
  4
  5use futures::{
  6    AsyncBufReadExt as _, StreamExt as _,
  7    channel::mpsc::{self},
  8    io::BufReader,
  9};
 10use gpui::{App, Entity, Task, Window};
 11use project::Project;
 12use runtimelib::{ExecutionState, JupyterMessage, KernelInfoReply};
 13use std::path::PathBuf;
 14use util::ResultExt;
 15
 16#[derive(Debug)]
 17pub struct SshRunningKernel {
 18    request_tx: mpsc::Sender<JupyterMessage>,
 19    stdin_tx: mpsc::Sender<JupyterMessage>,
 20    execution_state: ExecutionState,
 21    kernel_info: Option<KernelInfoReply>,
 22    working_directory: PathBuf,
 23    _ssh_tunnel_process: util::command::Child,
 24    _local_connection_file: PathBuf,
 25    kernel_id: String,
 26    project: Entity<Project>,
 27    project_id: u64,
 28}
 29
 30impl SshRunningKernel {
 31    pub fn new<S: KernelSession + 'static>(
 32        kernel_spec: SshRemoteKernelSpecification,
 33        working_directory: PathBuf,
 34        project: Entity<Project>,
 35        session: Entity<S>,
 36        window: &mut Window,
 37        cx: &mut App,
 38    ) -> Task<Result<Box<dyn RunningKernel>>> {
 39        let client = project.read(cx).client();
 40        let remote_client = project.read(cx).remote_client();
 41        let project_id = project
 42            .read(cx)
 43            .remote_id()
 44            .unwrap_or(proto::REMOTE_SERVER_PROJECT_ID);
 45
 46        window.spawn(cx, async move |cx| {
 47            let command = kernel_spec
 48                .kernelspec
 49                .argv
 50                .first()
 51                .cloned()
 52                .unwrap_or_default();
 53            let args = kernel_spec
 54                .kernelspec
 55                .argv
 56                .iter()
 57                .skip(1)
 58                .cloned()
 59                .collect();
 60
 61            let request = proto::SpawnKernel {
 62                kernel_name: kernel_spec.name.clone(),
 63                working_directory: working_directory.to_string_lossy().to_string(),
 64                project_id,
 65                command,
 66                args,
 67            };
 68            let response = if let Some(remote_client) = remote_client.as_ref() {
 69                remote_client
 70                    .read_with(cx, |client, _| client.proto_client())
 71                    .request(request)
 72                    .await?
 73            } else {
 74                client.request(request).await?
 75            };
 76
 77            let kernel_id = response.kernel_id.clone();
 78            let connection_info: serde_json::Value =
 79                serde_json::from_str(&response.connection_file)?;
 80
 81            // Setup SSH Tunneling - allocate local ports
 82            let mut local_ports = Vec::new();
 83            for _ in 0..5 {
 84                let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
 85                let port = listener.local_addr()?.port();
 86                drop(listener);
 87                local_ports.push(port);
 88            }
 89
 90            let remote_shell_port = connection_info["shell_port"]
 91                .as_u64()
 92                .context("missing shell_port")? as u16;
 93            let remote_iopub_port = connection_info["iopub_port"]
 94                .as_u64()
 95                .context("missing iopub_port")? as u16;
 96            let remote_stdin_port = connection_info["stdin_port"]
 97                .as_u64()
 98                .context("missing stdin_port")? as u16;
 99            let remote_control_port = connection_info["control_port"]
100                .as_u64()
101                .context("missing control_port")? as u16;
102            let remote_hb_port = connection_info["hb_port"]
103                .as_u64()
104                .context("missing hb_port")? as u16;
105
106            let forwards = vec![
107                (local_ports[0], "127.0.0.1".to_string(), remote_shell_port),
108                (local_ports[1], "127.0.0.1".to_string(), remote_iopub_port),
109                (local_ports[2], "127.0.0.1".to_string(), remote_stdin_port),
110                (local_ports[3], "127.0.0.1".to_string(), remote_control_port),
111                (local_ports[4], "127.0.0.1".to_string(), remote_hb_port),
112            ];
113
114            let remote_client = remote_client.ok_or_else(|| anyhow::anyhow!("no remote client"))?;
115            let command_template = cx.update(|_window, cx| {
116                remote_client.read(cx).build_forward_ports_command(forwards)
117            })??;
118
119            let mut command = util::command::new_command(&command_template.program);
120            command.args(&command_template.args);
121            command.envs(&command_template.env);
122
123            let mut ssh_tunnel_process = command.spawn().context("failed to spawn ssh tunnel")?;
124
125            let stderr = ssh_tunnel_process.stderr.take();
126            cx.spawn(async move |_cx| {
127                if let Some(stderr) = stderr {
128                    let reader = BufReader::new(stderr);
129                    let mut lines = reader.lines();
130                    while let Some(Ok(line)) = lines.next().await {
131                        log::warn!("ssh tunnel stderr: {}", line);
132                    }
133                }
134            })
135            .detach();
136
137            let stdout = ssh_tunnel_process.stdout.take();
138            cx.spawn(async move |_cx| {
139                if let Some(stdout) = stdout {
140                    let reader = BufReader::new(stdout);
141                    let mut lines = reader.lines();
142                    while let Some(Ok(line)) = lines.next().await {
143                        log::debug!("ssh tunnel stdout: {}", line);
144                    }
145                }
146            })
147            .detach();
148
149            // We might or might not need this, perhaps we can just wait for a second or test it this way
150            let shell_port = local_ports[0];
151            let max_attempts = 100;
152            let mut connected = false;
153            for attempt in 0..max_attempts {
154                match smol::net::TcpStream::connect(format!("127.0.0.1:{}", shell_port)).await {
155                    Ok(_) => {
156                        connected = true;
157                        log::info!(
158                            "SSH tunnel established for kernel {} on attempt {}",
159                            kernel_id,
160                            attempt + 1
161                        );
162                        // giving the tunnel a moment to fully establish forwarding
163                        cx.background_executor()
164                            .timer(std::time::Duration::from_millis(500))
165                            .await;
166                        break;
167                    }
168                    Err(err) => {
169                        if attempt % 10 == 0 {
170                            log::debug!(
171                                "Waiting for SSH tunnel (attempt {}/{}): {}",
172                                attempt + 1,
173                                max_attempts,
174                                err
175                            );
176                        }
177                        if attempt < max_attempts - 1 {
178                            cx.background_executor()
179                                .timer(std::time::Duration::from_millis(100))
180                                .await;
181                        }
182                    }
183                }
184            }
185            if !connected {
186                anyhow::bail!(
187                    "SSH tunnel failed to establish after {} attempts",
188                    max_attempts
189                );
190            }
191
192            let mut local_connection_info = connection_info.clone();
193            local_connection_info["shell_port"] = serde_json::json!(local_ports[0]);
194            local_connection_info["iopub_port"] = serde_json::json!(local_ports[1]);
195            local_connection_info["stdin_port"] = serde_json::json!(local_ports[2]);
196            local_connection_info["control_port"] = serde_json::json!(local_ports[3]);
197            local_connection_info["hb_port"] = serde_json::json!(local_ports[4]);
198            local_connection_info["ip"] = serde_json::json!("127.0.0.1");
199
200            let local_connection_file =
201                std::env::temp_dir().join(format!("zed_ssh_kernel_{}.json", kernel_id));
202            std::fs::write(
203                &local_connection_file,
204                serde_json::to_string_pretty(&local_connection_info)?,
205            )?;
206
207            // Parse connection info and create ZMQ connections
208            let connection_info_struct: runtimelib::ConnectionInfo =
209                serde_json::from_value(local_connection_info)?;
210            let session_id = uuid::Uuid::new_v4().to_string();
211
212            let output_socket = runtimelib::create_client_iopub_connection(
213                &connection_info_struct,
214                "",
215                &session_id,
216            )
217            .await
218            .context("failed to create iopub connection")?;
219
220            let peer_identity = runtimelib::peer_identity_for_session(&session_id)?;
221            let shell_socket = runtimelib::create_client_shell_connection_with_identity(
222                &connection_info_struct,
223                &session_id,
224                peer_identity.clone(),
225            )
226            .await
227            .context("failed to create shell connection")?;
228            let control_socket =
229                runtimelib::create_client_control_connection(&connection_info_struct, &session_id)
230                    .await
231                    .context("failed to create control connection")?;
232            let stdin_socket = runtimelib::create_client_stdin_connection_with_identity(
233                &connection_info_struct,
234                &session_id,
235                peer_identity,
236            )
237            .await
238            .context("failed to create stdin connection")?;
239
240            let (request_tx, stdin_tx) = start_kernel_tasks(
241                session.clone(),
242                output_socket,
243                shell_socket,
244                control_socket,
245                stdin_socket,
246                cx,
247            );
248
249            Ok(Box::new(SshRunningKernel {
250                request_tx,
251                stdin_tx,
252                execution_state: ExecutionState::Idle,
253                kernel_info: None,
254                working_directory,
255                _ssh_tunnel_process: ssh_tunnel_process,
256                _local_connection_file: local_connection_file,
257                kernel_id,
258                project,
259                project_id,
260            }) as Box<dyn RunningKernel>)
261        })
262    }
263}
264
265impl RunningKernel for SshRunningKernel {
266    fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
267        self.request_tx.clone()
268    }
269
270    fn stdin_tx(&self) -> mpsc::Sender<JupyterMessage> {
271        self.stdin_tx.clone()
272    }
273
274    fn working_directory(&self) -> &PathBuf {
275        &self.working_directory
276    }
277
278    fn execution_state(&self) -> &ExecutionState {
279        &self.execution_state
280    }
281
282    fn set_execution_state(&mut self, state: ExecutionState) {
283        self.execution_state = state;
284    }
285
286    fn kernel_info(&self) -> Option<&KernelInfoReply> {
287        self.kernel_info.as_ref()
288    }
289
290    fn set_kernel_info(&mut self, info: KernelInfoReply) {
291        self.kernel_info = Some(info);
292    }
293
294    fn force_shutdown(&mut self, _window: &mut Window, cx: &mut App) -> Task<Result<()>> {
295        let kernel_id = self.kernel_id.clone();
296        let project_id = self.project_id;
297        let client = self.project.read(cx).client();
298
299        cx.background_executor().spawn(async move {
300            let request = proto::KillKernel {
301                kernel_id,
302                project_id,
303            };
304            client.request::<proto::KillKernel>(request).await?;
305            Ok(())
306        })
307    }
308
309    fn kill(&mut self) {
310        self._ssh_tunnel_process.kill().log_err();
311    }
312}