1use futures::{channel::mpsc, SinkExt as _};
2use gpui::{Task, View, WindowContext};
3use http_client::{AsyncBody, HttpClient, Request};
4use jupyter_protocol::{ExecutionState, JupyterKernelspec, JupyterMessage, KernelInfoReply};
5
6use futures::StreamExt;
7use smol::io::AsyncReadExt as _;
8
9use crate::Session;
10
11use super::RunningKernel;
12use anyhow::Result;
13use jupyter_websocket_client::{
14 JupyterWebSocketReader, JupyterWebSocketWriter, KernelLaunchRequest, KernelSpecsResponse,
15 RemoteServer,
16};
17use std::{fmt::Debug, sync::Arc};
18
19#[derive(Debug, Clone)]
20pub struct RemoteKernelSpecification {
21 pub name: String,
22 pub url: String,
23 pub token: String,
24 pub kernelspec: JupyterKernelspec,
25}
26
27pub async fn launch_remote_kernel(
28 remote_server: &RemoteServer,
29 http_client: Arc<dyn HttpClient>,
30 kernel_name: &str,
31 _path: &str,
32) -> Result<String> {
33 //
34 let kernel_launch_request = KernelLaunchRequest {
35 name: kernel_name.to_string(),
36 // Note: since the path we have locally may not be the same as the one on the remote server,
37 // we don't send it. We'll have to evaluate this decisiion along the way.
38 path: None,
39 };
40
41 let kernel_launch_request = serde_json::to_string(&kernel_launch_request)?;
42
43 let request = Request::builder()
44 .method("POST")
45 .uri(&remote_server.api_url("/kernels"))
46 .header("Authorization", format!("token {}", remote_server.token))
47 .body(AsyncBody::from(kernel_launch_request))?;
48
49 let response = http_client.send(request).await?;
50
51 if !response.status().is_success() {
52 let mut body = String::new();
53 response.into_body().read_to_string(&mut body).await?;
54 return Err(anyhow::anyhow!("Failed to launch kernel: {}", body));
55 }
56
57 let mut body = String::new();
58 response.into_body().read_to_string(&mut body).await?;
59
60 let response: jupyter_websocket_client::Kernel = serde_json::from_str(&body)?;
61
62 Ok(response.id)
63}
64
65pub async fn list_remote_kernelspecs(
66 remote_server: RemoteServer,
67 http_client: Arc<dyn HttpClient>,
68) -> Result<Vec<RemoteKernelSpecification>> {
69 let url = remote_server.api_url("/kernelspecs");
70
71 let request = Request::builder()
72 .method("GET")
73 .uri(&url)
74 .header("Authorization", format!("token {}", remote_server.token))
75 .body(AsyncBody::default())?;
76
77 let response = http_client.send(request).await?;
78
79 if response.status().is_success() {
80 let mut body = response.into_body();
81
82 let mut body_bytes = Vec::new();
83 body.read_to_end(&mut body_bytes).await?;
84
85 let kernel_specs: KernelSpecsResponse = serde_json::from_slice(&body_bytes)?;
86
87 let remote_kernelspecs = kernel_specs
88 .kernelspecs
89 .into_iter()
90 .map(|(name, spec)| RemoteKernelSpecification {
91 name: name.clone(),
92 url: remote_server.base_url.clone(),
93 token: remote_server.token.clone(),
94 kernelspec: spec.spec,
95 })
96 .collect::<Vec<RemoteKernelSpecification>>();
97
98 if remote_kernelspecs.is_empty() {
99 Err(anyhow::anyhow!("No kernel specs found"))
100 } else {
101 Ok(remote_kernelspecs.clone())
102 }
103 } else {
104 Err(anyhow::anyhow!(
105 "Failed to fetch kernel specs: {}",
106 response.status()
107 ))
108 }
109}
110
111impl PartialEq for RemoteKernelSpecification {
112 fn eq(&self, other: &Self) -> bool {
113 self.name == other.name && self.url == other.url
114 }
115}
116
117impl Eq for RemoteKernelSpecification {}
118
/// A kernel running on a remote Jupyter server, reached over a websocket.
///
/// Built by [`RemoteRunningKernel::new`]. `Debug` is implemented by hand so the
/// server token is not printed.
pub struct RemoteRunningKernel {
    /// Server the kernel was launched on; its token authorizes REST calls such
    /// as the shutdown request.
    remote_server: RemoteServer,
    /// Loop reading websocket messages and routing them into the session.
    /// Never read; held so the task lives as long as the kernel (presumably
    /// dropping a gpui `Task` cancels it — confirm).
    _receiving_task: Task<Result<()>>,
    /// Loop forwarding messages from `request_tx` onto the websocket. Held
    /// for the same reason as `_receiving_task`.
    _routing_task: Task<Result<()>>,
    /// HTTP client used for REST calls against `remote_server`.
    http_client: Arc<dyn HttpClient>,
    /// Local working directory this kernel was started for.
    pub working_directory: std::path::PathBuf,
    /// Sender side of the channel drained by `_routing_task`.
    pub request_tx: mpsc::Sender<JupyterMessage>,
    /// Last known execution state; starts as `Idle` and is updated via
    /// `set_execution_state`.
    pub execution_state: ExecutionState,
    /// Kernel info reply, once one has been recorded via `set_kernel_info`.
    pub kernel_info: Option<KernelInfoReply>,
    /// Id assigned by the server when the kernel was launched.
    pub kernel_id: String,
}
130
131impl RemoteRunningKernel {
132 pub fn new(
133 kernelspec: RemoteKernelSpecification,
134 working_directory: std::path::PathBuf,
135 session: View<Session>,
136 cx: &mut WindowContext,
137 ) -> Task<Result<Box<dyn RunningKernel>>> {
138 let remote_server = RemoteServer {
139 base_url: kernelspec.url,
140 token: kernelspec.token,
141 };
142
143 let http_client = cx.http_client();
144
145 cx.spawn(|cx| async move {
146 let kernel_id = launch_remote_kernel(
147 &remote_server,
148 http_client.clone(),
149 &kernelspec.name,
150 working_directory.to_str().unwrap_or_default(),
151 )
152 .await?;
153
154 let (kernel_socket, _response) = remote_server.connect_to_kernel(&kernel_id).await?;
155
156 let (mut w, mut r): (JupyterWebSocketWriter, JupyterWebSocketReader) =
157 kernel_socket.split();
158
159 let (request_tx, mut request_rx) =
160 futures::channel::mpsc::channel::<JupyterMessage>(100);
161
162 let routing_task = cx.background_executor().spawn({
163 async move {
164 while let Some(message) = request_rx.next().await {
165 w.send(message).await.ok();
166 }
167 Ok(())
168 }
169 });
170
171 let receiving_task = cx.spawn({
172 let session = session.clone();
173
174 |mut cx| async move {
175 while let Some(message) = r.next().await {
176 match message {
177 Ok(message) => {
178 session
179 .update(&mut cx, |session, cx| {
180 session.route(&message, cx);
181 })
182 .ok();
183 }
184 Err(e) => {
185 log::error!("Error receiving message: {:?}", e);
186 }
187 }
188 }
189 Ok(())
190 }
191 });
192
193 anyhow::Ok(Box::new(Self {
194 _routing_task: routing_task,
195 _receiving_task: receiving_task,
196 remote_server,
197 working_directory,
198 request_tx,
199 // todo(kyle): pull this from the kernel API to start with
200 execution_state: ExecutionState::Idle,
201 kernel_info: None,
202 kernel_id,
203 http_client: http_client.clone(),
204 }) as Box<dyn RunningKernel>)
205 })
206 }
207}
208
209impl Debug for RemoteRunningKernel {
210 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
211 f.debug_struct("RemoteRunningKernel")
212 // custom debug that keeps tokens out of logs
213 .field("remote_server url", &self.remote_server.base_url)
214 .field("working_directory", &self.working_directory)
215 .field("request_tx", &self.request_tx)
216 .field("execution_state", &self.execution_state)
217 .field("kernel_info", &self.kernel_info)
218 .finish()
219 }
220}
221
222impl RunningKernel for RemoteRunningKernel {
223 fn request_tx(&self) -> futures::channel::mpsc::Sender<runtimelib::JupyterMessage> {
224 self.request_tx.clone()
225 }
226
227 fn working_directory(&self) -> &std::path::PathBuf {
228 &self.working_directory
229 }
230
231 fn execution_state(&self) -> &runtimelib::ExecutionState {
232 &self.execution_state
233 }
234
235 fn set_execution_state(&mut self, state: runtimelib::ExecutionState) {
236 self.execution_state = state;
237 }
238
239 fn kernel_info(&self) -> Option<&runtimelib::KernelInfoReply> {
240 self.kernel_info.as_ref()
241 }
242
243 fn set_kernel_info(&mut self, info: runtimelib::KernelInfoReply) {
244 self.kernel_info = Some(info);
245 }
246
247 fn force_shutdown(&mut self, cx: &mut WindowContext) -> Task<anyhow::Result<()>> {
248 let url = self
249 .remote_server
250 .api_url(&format!("/kernels/{}", self.kernel_id));
251 let token = self.remote_server.token.clone();
252 let http_client = self.http_client.clone();
253
254 cx.spawn(|_| async move {
255 let request = Request::builder()
256 .method("DELETE")
257 .uri(&url)
258 .header("Authorization", format!("token {}", token))
259 .body(AsyncBody::default())?;
260
261 let response = http_client.send(request).await?;
262
263 if response.status().is_success() {
264 Ok(())
265 } else {
266 Err(anyhow::anyhow!(
267 "Failed to shutdown kernel: {}",
268 response.status()
269 ))
270 }
271 })
272 }
273}