remote_kernels.rs

  1use futures::{channel::mpsc, SinkExt as _};
  2use gpui::{Task, View, WindowContext};
  3use http_client::{AsyncBody, HttpClient, Request};
  4use jupyter_protocol::{ExecutionState, JupyterMessage, KernelInfoReply};
  5use runtimelib::JupyterKernelspec;
  6
  7use futures::StreamExt;
  8use smol::io::AsyncReadExt as _;
  9
 10use crate::Session;
 11
 12use super::RunningKernel;
 13use anyhow::Result;
 14use jupyter_websocket_client::{
 15    JupyterWebSocketReader, JupyterWebSocketWriter, KernelLaunchRequest, KernelSpecsResponse,
 16    RemoteServer,
 17};
 18use std::{fmt::Debug, sync::Arc};
 19
 20#[derive(Debug, Clone)]
 21pub struct RemoteKernelSpecification {
 22    pub name: String,
 23    pub url: String,
 24    pub token: String,
 25    pub kernelspec: JupyterKernelspec,
 26}
 27
 28pub async fn launch_remote_kernel(
 29    remote_server: &RemoteServer,
 30    http_client: Arc<dyn HttpClient>,
 31    kernel_name: &str,
 32    _path: &str,
 33) -> Result<String> {
 34    //
 35    let kernel_launch_request = KernelLaunchRequest {
 36        name: kernel_name.to_string(),
 37        // todo: add path to runtimelib
 38        // path,
 39    };
 40
 41    let kernel_launch_request = serde_json::to_string(&kernel_launch_request)?;
 42
 43    let request = Request::builder()
 44        .method("POST")
 45        .uri(&remote_server.api_url("/kernels"))
 46        .header("Authorization", format!("token {}", remote_server.token))
 47        .body(AsyncBody::from(kernel_launch_request))?;
 48
 49    let response = http_client.send(request).await?;
 50
 51    if !response.status().is_success() {
 52        let mut body = String::new();
 53        response.into_body().read_to_string(&mut body).await?;
 54        return Err(anyhow::anyhow!("Failed to launch kernel: {}", body));
 55    }
 56
 57    let mut body = String::new();
 58    response.into_body().read_to_string(&mut body).await?;
 59
 60    let response: jupyter_websocket_client::Kernel = serde_json::from_str(&body)?;
 61
 62    Ok(response.id)
 63}
 64
 65pub async fn list_remote_kernelspecs(
 66    remote_server: RemoteServer,
 67    http_client: Arc<dyn HttpClient>,
 68) -> Result<Vec<RemoteKernelSpecification>> {
 69    let url = remote_server.api_url("/kernelspecs");
 70
 71    let request = Request::builder()
 72        .method("GET")
 73        .uri(&url)
 74        .header("Authorization", format!("token {}", remote_server.token))
 75        .body(AsyncBody::default())?;
 76
 77    let response = http_client.send(request).await?;
 78
 79    if response.status().is_success() {
 80        let mut body = response.into_body();
 81
 82        let mut body_bytes = Vec::new();
 83        body.read_to_end(&mut body_bytes).await?;
 84
 85        let kernel_specs: KernelSpecsResponse = serde_json::from_slice(&body_bytes)?;
 86
 87        let remote_kernelspecs = kernel_specs
 88            .kernelspecs
 89            .into_iter()
 90            .map(|(name, spec)| RemoteKernelSpecification {
 91                name: name.clone(),
 92                url: remote_server.base_url.clone(),
 93                token: remote_server.token.clone(),
 94                // todo: line up the jupyter kernelspec from runtimelib with
 95                //       the kernelspec pulled from the API
 96                //
 97                //        There are _small_ differences, so we may just want a impl `From`
 98                kernelspec: JupyterKernelspec {
 99                    argv: spec.spec.argv,
100                    display_name: spec.spec.display_name,
101                    language: spec.spec.language,
102                    // todo: fix up mismatch in types here
103                    metadata: None,
104                    interrupt_mode: None,
105                    env: None,
106                },
107            })
108            .collect::<Vec<RemoteKernelSpecification>>();
109
110        if remote_kernelspecs.is_empty() {
111            Err(anyhow::anyhow!("No kernel specs found"))
112        } else {
113            Ok(remote_kernelspecs.clone())
114        }
115    } else {
116        Err(anyhow::anyhow!(
117            "Failed to fetch kernel specs: {}",
118            response.status()
119        ))
120    }
121}
122
123impl PartialEq for RemoteKernelSpecification {
124    fn eq(&self, other: &Self) -> bool {
125        self.name == other.name && self.url == other.url
126    }
127}
128
129impl Eq for RemoteKernelSpecification {}
130
131pub struct RemoteRunningKernel {
132    remote_server: RemoteServer,
133    _receiving_task: Task<Result<()>>,
134    _routing_task: Task<Result<()>>,
135    http_client: Arc<dyn HttpClient>,
136    pub working_directory: std::path::PathBuf,
137    pub request_tx: mpsc::Sender<JupyterMessage>,
138    pub execution_state: ExecutionState,
139    pub kernel_info: Option<KernelInfoReply>,
140    pub kernel_id: String,
141}
142
143impl RemoteRunningKernel {
144    pub fn new(
145        kernelspec: RemoteKernelSpecification,
146        working_directory: std::path::PathBuf,
147        session: View<Session>,
148        cx: &mut WindowContext,
149    ) -> Task<Result<Box<dyn RunningKernel>>> {
150        let remote_server = RemoteServer {
151            base_url: kernelspec.url,
152            token: kernelspec.token,
153        };
154
155        let http_client = cx.http_client();
156
157        cx.spawn(|cx| async move {
158            let kernel_id = launch_remote_kernel(
159                &remote_server,
160                http_client.clone(),
161                &kernelspec.name,
162                working_directory.to_str().unwrap_or_default(),
163            )
164            .await?;
165
166            let kernel_socket = remote_server.connect_to_kernel(&kernel_id).await?;
167
168            let (mut w, mut r): (JupyterWebSocketWriter, JupyterWebSocketReader) =
169                kernel_socket.split();
170
171            let (request_tx, mut request_rx) =
172                futures::channel::mpsc::channel::<JupyterMessage>(100);
173
174            let routing_task = cx.background_executor().spawn({
175                async move {
176                    while let Some(message) = request_rx.next().await {
177                        w.send(message).await.ok();
178                    }
179                    Ok(())
180                }
181            });
182
183            let receiving_task = cx.spawn({
184                let session = session.clone();
185
186                |mut cx| async move {
187                    while let Some(message) = r.next().await {
188                        match message {
189                            Ok(message) => {
190                                session
191                                    .update(&mut cx, |session, cx| {
192                                        session.route(&message, cx);
193                                    })
194                                    .ok();
195                            }
196                            Err(e) => {
197                                log::error!("Error receiving message: {:?}", e);
198                            }
199                        }
200                    }
201                    Ok(())
202                }
203            });
204
205            anyhow::Ok(Box::new(Self {
206                _routing_task: routing_task,
207                _receiving_task: receiving_task,
208                remote_server,
209                working_directory,
210                request_tx,
211                // todo(kyle): pull this from the kernel API to start with
212                execution_state: ExecutionState::Idle,
213                kernel_info: None,
214                kernel_id,
215                http_client: http_client.clone(),
216            }) as Box<dyn RunningKernel>)
217        })
218    }
219}
220
221impl Debug for RemoteRunningKernel {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        f.debug_struct("RemoteRunningKernel")
224            // custom debug that keeps tokens out of logs
225            .field("remote_server url", &self.remote_server.base_url)
226            .field("working_directory", &self.working_directory)
227            .field("request_tx", &self.request_tx)
228            .field("execution_state", &self.execution_state)
229            .field("kernel_info", &self.kernel_info)
230            .finish()
231    }
232}
233
234impl RunningKernel for RemoteRunningKernel {
235    fn request_tx(&self) -> futures::channel::mpsc::Sender<runtimelib::JupyterMessage> {
236        self.request_tx.clone()
237    }
238
239    fn working_directory(&self) -> &std::path::PathBuf {
240        &self.working_directory
241    }
242
243    fn execution_state(&self) -> &runtimelib::ExecutionState {
244        &self.execution_state
245    }
246
247    fn set_execution_state(&mut self, state: runtimelib::ExecutionState) {
248        self.execution_state = state;
249    }
250
251    fn kernel_info(&self) -> Option<&runtimelib::KernelInfoReply> {
252        self.kernel_info.as_ref()
253    }
254
255    fn set_kernel_info(&mut self, info: runtimelib::KernelInfoReply) {
256        self.kernel_info = Some(info);
257    }
258
259    fn force_shutdown(&mut self, cx: &mut WindowContext) -> Task<anyhow::Result<()>> {
260        let url = self
261            .remote_server
262            .api_url(&format!("/kernels/{}", self.kernel_id));
263        let token = self.remote_server.token.clone();
264        let http_client = self.http_client.clone();
265
266        cx.spawn(|_| async move {
267            let request = Request::builder()
268                .method("DELETE")
269                .uri(&url)
270                .header("Authorization", format!("token {}", token))
271                .body(AsyncBody::default())?;
272
273            let response = http_client.send(request).await?;
274
275            if response.status().is_success() {
276                Ok(())
277            } else {
278                Err(anyhow::anyhow!(
279                    "Failed to shutdown kernel: {}",
280                    response.status()
281                ))
282            }
283        })
284    }
285}