1use super::{KernelSession, RunningKernel, SshRemoteKernelSpecification, start_kernel_tasks};
2use anyhow::{Context as _, Result};
3use client::proto;
4
5use futures::{
6 AsyncBufReadExt as _, StreamExt as _,
7 channel::mpsc::{self},
8 io::BufReader,
9};
10use gpui::{App, Entity, Task, Window};
11use project::Project;
12use runtimelib::{ExecutionState, JupyterMessage, KernelInfoReply};
13use std::path::PathBuf;
14use util::ResultExt;
15
/// A Jupyter kernel running on the remote host of an SSH project, reached from this machine
/// through an ssh process that forwards local ports to the kernel's ports on the remote host.
#[derive(Debug)]
pub struct SshRunningKernel {
    /// Sender returned by `start_kernel_tasks` for messages to the kernel.
    request_tx: mpsc::Sender<JupyterMessage>,
    /// Second sender returned by `start_kernel_tasks`; presumably feeds the kernel's stdin
    /// socket (TODO confirm against `start_kernel_tasks`).
    stdin_tx: mpsc::Sender<JupyterMessage>,
    /// Last known execution state; starts as `Idle` and is replaced via `set_execution_state`.
    execution_state: ExecutionState,
    /// Kernel info reply; `None` until `set_kernel_info` is called.
    kernel_info: Option<KernelInfoReply>,
    /// Working directory sent to the remote server when the kernel was spawned.
    working_directory: PathBuf,
    /// The ssh port-forwarding process. Held for the kernel's lifetime; `kill` terminates it.
    _ssh_tunnel_process: util::command::Child,
    /// Path of the connection file, rewritten to point at the local tunnel ports, that was
    /// written to the system temp directory.
    _local_connection_file: PathBuf,
    /// Kernel id assigned by the remote server in its `SpawnKernel` response.
    kernel_id: String,
    /// Project the kernel belongs to; used to reach its clients when shutting down.
    project: Entity<Project>,
    /// Project id sent with kernel requests (`REMOTE_SERVER_PROJECT_ID` when the project has
    /// no remote id).
    project_id: u64,
}
29
30impl SshRunningKernel {
31 pub fn new<S: KernelSession + 'static>(
32 kernel_spec: SshRemoteKernelSpecification,
33 working_directory: PathBuf,
34 project: Entity<Project>,
35 session: Entity<S>,
36 window: &mut Window,
37 cx: &mut App,
38 ) -> Task<Result<Box<dyn RunningKernel>>> {
39 let client = project.read(cx).client();
40 let remote_client = project.read(cx).remote_client();
41 let project_id = project
42 .read(cx)
43 .remote_id()
44 .unwrap_or(proto::REMOTE_SERVER_PROJECT_ID);
45
46 window.spawn(cx, async move |cx| {
47 let command = kernel_spec
48 .kernelspec
49 .argv
50 .first()
51 .cloned()
52 .unwrap_or_default();
53 let args = kernel_spec
54 .kernelspec
55 .argv
56 .iter()
57 .skip(1)
58 .cloned()
59 .collect();
60
61 let request = proto::SpawnKernel {
62 kernel_name: kernel_spec.name.clone(),
63 working_directory: working_directory.to_string_lossy().to_string(),
64 project_id,
65 command,
66 args,
67 };
68 let response = if let Some(remote_client) = remote_client.as_ref() {
69 remote_client
70 .read_with(cx, |client, _| client.proto_client())
71 .request(request)
72 .await?
73 } else {
74 client.request(request).await?
75 };
76
77 let kernel_id = response.kernel_id.clone();
78 let connection_info: serde_json::Value =
79 serde_json::from_str(&response.connection_file)?;
80
81 // Setup SSH Tunneling - allocate local ports
82 let mut local_ports = Vec::new();
83 for _ in 0..5 {
84 let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
85 let port = listener.local_addr()?.port();
86 drop(listener);
87 local_ports.push(port);
88 }
89
90 let remote_shell_port = connection_info["shell_port"]
91 .as_u64()
92 .context("missing shell_port")? as u16;
93 let remote_iopub_port = connection_info["iopub_port"]
94 .as_u64()
95 .context("missing iopub_port")? as u16;
96 let remote_stdin_port = connection_info["stdin_port"]
97 .as_u64()
98 .context("missing stdin_port")? as u16;
99 let remote_control_port = connection_info["control_port"]
100 .as_u64()
101 .context("missing control_port")? as u16;
102 let remote_hb_port = connection_info["hb_port"]
103 .as_u64()
104 .context("missing hb_port")? as u16;
105
106 let forwards = vec![
107 (local_ports[0], "127.0.0.1".to_string(), remote_shell_port),
108 (local_ports[1], "127.0.0.1".to_string(), remote_iopub_port),
109 (local_ports[2], "127.0.0.1".to_string(), remote_stdin_port),
110 (local_ports[3], "127.0.0.1".to_string(), remote_control_port),
111 (local_ports[4], "127.0.0.1".to_string(), remote_hb_port),
112 ];
113
114 let remote_client = remote_client.ok_or_else(|| anyhow::anyhow!("no remote client"))?;
115 let command_template = cx.update(|_window, cx| {
116 remote_client.read(cx).build_forward_ports_command(forwards)
117 })??;
118
119 let mut command = util::command::new_command(&command_template.program);
120 command.args(&command_template.args);
121 command.envs(&command_template.env);
122
123 let mut ssh_tunnel_process = command.spawn().context("failed to spawn ssh tunnel")?;
124
125 let stderr = ssh_tunnel_process.stderr.take();
126 cx.spawn(async move |_cx| {
127 if let Some(stderr) = stderr {
128 let reader = BufReader::new(stderr);
129 let mut lines = reader.lines();
130 while let Some(Ok(line)) = lines.next().await {
131 log::warn!("ssh tunnel stderr: {}", line);
132 }
133 }
134 })
135 .detach();
136
137 let stdout = ssh_tunnel_process.stdout.take();
138 cx.spawn(async move |_cx| {
139 if let Some(stdout) = stdout {
140 let reader = BufReader::new(stdout);
141 let mut lines = reader.lines();
142 while let Some(Ok(line)) = lines.next().await {
143 log::debug!("ssh tunnel stdout: {}", line);
144 }
145 }
146 })
147 .detach();
148
149 // We might or might not need this, perhaps we can just wait for a second or test it this way
150 let shell_port = local_ports[0];
151 let max_attempts = 100;
152 let mut connected = false;
153 for attempt in 0..max_attempts {
154 match smol::net::TcpStream::connect(format!("127.0.0.1:{}", shell_port)).await {
155 Ok(_) => {
156 connected = true;
157 log::info!(
158 "SSH tunnel established for kernel {} on attempt {}",
159 kernel_id,
160 attempt + 1
161 );
162 // giving the tunnel a moment to fully establish forwarding
163 cx.background_executor()
164 .timer(std::time::Duration::from_millis(500))
165 .await;
166 break;
167 }
168 Err(err) => {
169 if attempt % 10 == 0 {
170 log::debug!(
171 "Waiting for SSH tunnel (attempt {}/{}): {}",
172 attempt + 1,
173 max_attempts,
174 err
175 );
176 }
177 if attempt < max_attempts - 1 {
178 cx.background_executor()
179 .timer(std::time::Duration::from_millis(100))
180 .await;
181 }
182 }
183 }
184 }
185 if !connected {
186 anyhow::bail!(
187 "SSH tunnel failed to establish after {} attempts",
188 max_attempts
189 );
190 }
191
192 let mut local_connection_info = connection_info.clone();
193 local_connection_info["shell_port"] = serde_json::json!(local_ports[0]);
194 local_connection_info["iopub_port"] = serde_json::json!(local_ports[1]);
195 local_connection_info["stdin_port"] = serde_json::json!(local_ports[2]);
196 local_connection_info["control_port"] = serde_json::json!(local_ports[3]);
197 local_connection_info["hb_port"] = serde_json::json!(local_ports[4]);
198 local_connection_info["ip"] = serde_json::json!("127.0.0.1");
199
200 let local_connection_file =
201 std::env::temp_dir().join(format!("zed_ssh_kernel_{}.json", kernel_id));
202 std::fs::write(
203 &local_connection_file,
204 serde_json::to_string_pretty(&local_connection_info)?,
205 )?;
206
207 // Parse connection info and create ZMQ connections
208 let connection_info_struct: runtimelib::ConnectionInfo =
209 serde_json::from_value(local_connection_info)?;
210 let session_id = uuid::Uuid::new_v4().to_string();
211
212 let output_socket = runtimelib::create_client_iopub_connection(
213 &connection_info_struct,
214 "",
215 &session_id,
216 )
217 .await
218 .context("failed to create iopub connection")?;
219
220 let peer_identity = runtimelib::peer_identity_for_session(&session_id)?;
221 let shell_socket = runtimelib::create_client_shell_connection_with_identity(
222 &connection_info_struct,
223 &session_id,
224 peer_identity.clone(),
225 )
226 .await
227 .context("failed to create shell connection")?;
228 let control_socket =
229 runtimelib::create_client_control_connection(&connection_info_struct, &session_id)
230 .await
231 .context("failed to create control connection")?;
232 let stdin_socket = runtimelib::create_client_stdin_connection_with_identity(
233 &connection_info_struct,
234 &session_id,
235 peer_identity,
236 )
237 .await
238 .context("failed to create stdin connection")?;
239
240 let (request_tx, stdin_tx) = start_kernel_tasks(
241 session.clone(),
242 output_socket,
243 shell_socket,
244 control_socket,
245 stdin_socket,
246 cx,
247 );
248
249 Ok(Box::new(SshRunningKernel {
250 request_tx,
251 stdin_tx,
252 execution_state: ExecutionState::Idle,
253 kernel_info: None,
254 working_directory,
255 _ssh_tunnel_process: ssh_tunnel_process,
256 _local_connection_file: local_connection_file,
257 kernel_id,
258 project,
259 project_id,
260 }) as Box<dyn RunningKernel>)
261 })
262 }
263}
264
265impl RunningKernel for SshRunningKernel {
266 fn request_tx(&self) -> mpsc::Sender<JupyterMessage> {
267 self.request_tx.clone()
268 }
269
270 fn stdin_tx(&self) -> mpsc::Sender<JupyterMessage> {
271 self.stdin_tx.clone()
272 }
273
274 fn working_directory(&self) -> &PathBuf {
275 &self.working_directory
276 }
277
278 fn execution_state(&self) -> &ExecutionState {
279 &self.execution_state
280 }
281
282 fn set_execution_state(&mut self, state: ExecutionState) {
283 self.execution_state = state;
284 }
285
286 fn kernel_info(&self) -> Option<&KernelInfoReply> {
287 self.kernel_info.as_ref()
288 }
289
290 fn set_kernel_info(&mut self, info: KernelInfoReply) {
291 self.kernel_info = Some(info);
292 }
293
294 fn force_shutdown(&mut self, _window: &mut Window, cx: &mut App) -> Task<Result<()>> {
295 let kernel_id = self.kernel_id.clone();
296 let project_id = self.project_id;
297 let client = self.project.read(cx).client();
298
299 cx.background_executor().spawn(async move {
300 let request = proto::KillKernel {
301 kernel_id,
302 project_id,
303 };
304 client.request::<proto::KillKernel>(request).await?;
305 Ok(())
306 })
307 }
308
309 fn kill(&mut self) {
310 self._ssh_tunnel_process.kill().log_err();
311 }
312}