1use futures::{channel::mpsc, SinkExt as _};
2use gpui::{Task, View, WindowContext};
3use http_client::{AsyncBody, HttpClient, Request};
4use jupyter_protocol::{ExecutionState, JupyterMessage, KernelInfoReply};
5use runtimelib::JupyterKernelspec;
6
7use futures::StreamExt;
8use smol::io::AsyncReadExt as _;
9
10use crate::Session;
11
12use super::RunningKernel;
13use anyhow::Result;
14use jupyter_websocket_client::{
15 JupyterWebSocketReader, JupyterWebSocketWriter, KernelLaunchRequest, KernelSpecsResponse,
16 RemoteServer,
17};
18use std::{fmt::Debug, sync::Arc};
19
20#[derive(Debug, Clone)]
21pub struct RemoteKernelSpecification {
22 pub name: String,
23 pub url: String,
24 pub token: String,
25 pub kernelspec: JupyterKernelspec,
26}
27
28pub async fn launch_remote_kernel(
29 remote_server: &RemoteServer,
30 http_client: Arc<dyn HttpClient>,
31 kernel_name: &str,
32 _path: &str,
33) -> Result<String> {
34 //
35 let kernel_launch_request = KernelLaunchRequest {
36 name: kernel_name.to_string(),
37 // todo: add path to runtimelib
38 // path,
39 };
40
41 let kernel_launch_request = serde_json::to_string(&kernel_launch_request)?;
42
43 let request = Request::builder()
44 .method("POST")
45 .uri(&remote_server.api_url("/kernels"))
46 .header("Authorization", format!("token {}", remote_server.token))
47 .body(AsyncBody::from(kernel_launch_request))?;
48
49 let response = http_client.send(request).await?;
50
51 if !response.status().is_success() {
52 let mut body = String::new();
53 response.into_body().read_to_string(&mut body).await?;
54 return Err(anyhow::anyhow!("Failed to launch kernel: {}", body));
55 }
56
57 let mut body = String::new();
58 response.into_body().read_to_string(&mut body).await?;
59
60 let response: jupyter_websocket_client::Kernel = serde_json::from_str(&body)?;
61
62 Ok(response.id)
63}
64
65pub async fn list_remote_kernelspecs(
66 remote_server: RemoteServer,
67 http_client: Arc<dyn HttpClient>,
68) -> Result<Vec<RemoteKernelSpecification>> {
69 let url = remote_server.api_url("/kernelspecs");
70
71 let request = Request::builder()
72 .method("GET")
73 .uri(&url)
74 .header("Authorization", format!("token {}", remote_server.token))
75 .body(AsyncBody::default())?;
76
77 let response = http_client.send(request).await?;
78
79 if response.status().is_success() {
80 let mut body = response.into_body();
81
82 let mut body_bytes = Vec::new();
83 body.read_to_end(&mut body_bytes).await?;
84
85 let kernel_specs: KernelSpecsResponse = serde_json::from_slice(&body_bytes)?;
86
87 let remote_kernelspecs = kernel_specs
88 .kernelspecs
89 .into_iter()
90 .map(|(name, spec)| RemoteKernelSpecification {
91 name: name.clone(),
92 url: remote_server.base_url.clone(),
93 token: remote_server.token.clone(),
94 // todo: line up the jupyter kernelspec from runtimelib with
95 // the kernelspec pulled from the API
96 //
97 // There are _small_ differences, so we may just want a impl `From`
98 kernelspec: JupyterKernelspec {
99 argv: spec.spec.argv,
100 display_name: spec.spec.display_name,
101 language: spec.spec.language,
102 // todo: fix up mismatch in types here
103 metadata: None,
104 interrupt_mode: None,
105 env: None,
106 },
107 })
108 .collect::<Vec<RemoteKernelSpecification>>();
109
110 if remote_kernelspecs.is_empty() {
111 Err(anyhow::anyhow!("No kernel specs found"))
112 } else {
113 Ok(remote_kernelspecs.clone())
114 }
115 } else {
116 Err(anyhow::anyhow!(
117 "Failed to fetch kernel specs: {}",
118 response.status()
119 ))
120 }
121}
122
123impl PartialEq for RemoteKernelSpecification {
124 fn eq(&self, other: &Self) -> bool {
125 self.name == other.name && self.url == other.url
126 }
127}
128
129impl Eq for RemoteKernelSpecification {}
130
/// A kernel launched on a remote Jupyter server and driven over a WebSocket.
///
/// NOTE(review): fields drop in declaration order, so the two background tasks
/// are dropped before `request_tx`; keep the order unless that is re-checked.
pub struct RemoteRunningKernel {
    /// Server the kernel runs on (base URL and access token).
    remote_server: RemoteServer,
    /// Task that routes messages read from the kernel's socket into the session.
    _receiving_task: Task<Result<()>>,
    /// Task that forwards messages sent on `request_tx` to the kernel's socket.
    _routing_task: Task<Result<()>>,
    /// Client for REST calls to the server, e.g. the DELETE in `force_shutdown`.
    http_client: Arc<dyn HttpClient>,
    /// Working directory supplied when this kernel was created.
    pub working_directory: std::path::PathBuf,
    /// Sender whose messages the routing task writes to the kernel.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Last execution state recorded for the kernel (starts as `Idle`).
    pub execution_state: ExecutionState,
    /// Kernel info reply, once one has been stored via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
    /// Kernel id returned by the server when the kernel was launched.
    pub kernel_id: String,
}
142
143impl RemoteRunningKernel {
144 pub fn new(
145 kernelspec: RemoteKernelSpecification,
146 working_directory: std::path::PathBuf,
147 session: View<Session>,
148 cx: &mut WindowContext,
149 ) -> Task<Result<Box<dyn RunningKernel>>> {
150 let remote_server = RemoteServer {
151 base_url: kernelspec.url,
152 token: kernelspec.token,
153 };
154
155 let http_client = cx.http_client();
156
157 cx.spawn(|cx| async move {
158 let kernel_id = launch_remote_kernel(
159 &remote_server,
160 http_client.clone(),
161 &kernelspec.name,
162 working_directory.to_str().unwrap_or_default(),
163 )
164 .await?;
165
166 let kernel_socket = remote_server.connect_to_kernel(&kernel_id).await?;
167
168 let (mut w, mut r): (JupyterWebSocketWriter, JupyterWebSocketReader) =
169 kernel_socket.split();
170
171 let (request_tx, mut request_rx) =
172 futures::channel::mpsc::channel::<JupyterMessage>(100);
173
174 let routing_task = cx.background_executor().spawn({
175 async move {
176 while let Some(message) = request_rx.next().await {
177 w.send(message).await.ok();
178 }
179 Ok(())
180 }
181 });
182
183 let receiving_task = cx.spawn({
184 let session = session.clone();
185
186 |mut cx| async move {
187 while let Some(message) = r.next().await {
188 match message {
189 Ok(message) => {
190 session
191 .update(&mut cx, |session, cx| {
192 session.route(&message, cx);
193 })
194 .ok();
195 }
196 Err(e) => {
197 log::error!("Error receiving message: {:?}", e);
198 }
199 }
200 }
201 Ok(())
202 }
203 });
204
205 anyhow::Ok(Box::new(Self {
206 _routing_task: routing_task,
207 _receiving_task: receiving_task,
208 remote_server,
209 working_directory,
210 request_tx,
211 // todo(kyle): pull this from the kernel API to start with
212 execution_state: ExecutionState::Idle,
213 kernel_info: None,
214 kernel_id,
215 http_client: http_client.clone(),
216 }) as Box<dyn RunningKernel>)
217 })
218 }
219}
220
221impl Debug for RemoteRunningKernel {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 f.debug_struct("RemoteRunningKernel")
224 // custom debug that keeps tokens out of logs
225 .field("remote_server url", &self.remote_server.base_url)
226 .field("working_directory", &self.working_directory)
227 .field("request_tx", &self.request_tx)
228 .field("execution_state", &self.execution_state)
229 .field("kernel_info", &self.kernel_info)
230 .finish()
231 }
232}
233
234impl RunningKernel for RemoteRunningKernel {
235 fn request_tx(&self) -> futures::channel::mpsc::Sender<runtimelib::JupyterMessage> {
236 self.request_tx.clone()
237 }
238
239 fn working_directory(&self) -> &std::path::PathBuf {
240 &self.working_directory
241 }
242
243 fn execution_state(&self) -> &runtimelib::ExecutionState {
244 &self.execution_state
245 }
246
247 fn set_execution_state(&mut self, state: runtimelib::ExecutionState) {
248 self.execution_state = state;
249 }
250
251 fn kernel_info(&self) -> Option<&runtimelib::KernelInfoReply> {
252 self.kernel_info.as_ref()
253 }
254
255 fn set_kernel_info(&mut self, info: runtimelib::KernelInfoReply) {
256 self.kernel_info = Some(info);
257 }
258
259 fn force_shutdown(&mut self, cx: &mut WindowContext) -> Task<anyhow::Result<()>> {
260 let url = self
261 .remote_server
262 .api_url(&format!("/kernels/{}", self.kernel_id));
263 let token = self.remote_server.token.clone();
264 let http_client = self.http_client.clone();
265
266 cx.spawn(|_| async move {
267 let request = Request::builder()
268 .method("DELETE")
269 .uri(&url)
270 .header("Authorization", format!("token {}", token))
271 .body(AsyncBody::default())?;
272
273 let response = http_client.send(request).await?;
274
275 if response.status().is_success() {
276 Ok(())
277 } else {
278 Err(anyhow::anyhow!(
279 "Failed to shutdown kernel: {}",
280 response.status()
281 ))
282 }
283 })
284 }
285}