ssh_connections.rs

  1use std::{path::PathBuf, sync::Arc, time::Duration};
  2
  3use anyhow::Result;
  4use auto_update::AutoUpdater;
  5use editor::Editor;
  6use futures::channel::oneshot;
  7use gpui::{
  8    percentage, px, Animation, AnimationExt, AnyWindowHandle, AsyncAppContext, DismissEvent,
  9    EventEmitter, FocusableView, ParentElement as _, PromptLevel, Render, SemanticVersion,
 10    SharedString, Task, TextStyleRefinement, Transformation, View,
 11};
 12use gpui::{AppContext, Model};
 13
 14use language::CursorShape;
 15use markdown::{Markdown, MarkdownStyle};
 16use release_channel::{AppVersion, ReleaseChannel};
 17use remote::{SshConnectionOptions, SshPlatform, SshRemoteClient};
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use settings::{Settings, SettingsSources};
 21use theme::ThemeSettings;
 22use ui::{
 23    div, h_flex, prelude::*, v_flex, ActiveTheme, Color, Icon, IconName, IconSize,
 24    InteractiveElement, IntoElement, Label, LabelCommon, Styled, ViewContext, VisualContext,
 25    WindowContext,
 26};
 27use workspace::{AppState, ModalView, Workspace};
 28
 29#[derive(Deserialize)]
 30pub struct SshSettings {
 31    pub ssh_connections: Option<Vec<SshConnection>>,
 32}
 33
 34impl SshSettings {
 35    pub fn ssh_connections(&self) -> impl Iterator<Item = SshConnection> {
 36        self.ssh_connections.clone().into_iter().flatten()
 37    }
 38
 39    pub fn args_for(
 40        &self,
 41        host: &str,
 42        port: Option<u16>,
 43        user: &Option<String>,
 44    ) -> Option<Vec<String>> {
 45        self.ssh_connections()
 46            .filter_map(|conn| {
 47                if conn.host == host && &conn.username == user && conn.port == port {
 48                    Some(conn.args)
 49                } else {
 50                    None
 51                }
 52            })
 53            .next()
 54    }
 55    pub fn nickname_for(
 56        &self,
 57        host: &str,
 58        port: Option<u16>,
 59        user: &Option<String>,
 60    ) -> Option<SharedString> {
 61        self.ssh_connections()
 62            .filter_map(|conn| {
 63                if conn.host == host && &conn.username == user && conn.port == port {
 64                    Some(conn.nickname)
 65                } else {
 66                    None
 67                }
 68            })
 69            .next()
 70            .flatten()
 71    }
 72}
 73
 74#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
 75pub struct SshConnection {
 76    pub host: SharedString,
 77    #[serde(skip_serializing_if = "Option::is_none")]
 78    pub username: Option<String>,
 79    #[serde(skip_serializing_if = "Option::is_none")]
 80    pub port: Option<u16>,
 81    pub projects: Vec<SshProject>,
 82    /// Name to use for this server in UI.
 83    #[serde(skip_serializing_if = "Option::is_none")]
 84    pub nickname: Option<SharedString>,
 85    #[serde(skip_serializing_if = "Vec::is_empty")]
 86    #[serde(default)]
 87    pub args: Vec<String>,
 88}
 89impl From<SshConnection> for SshConnectionOptions {
 90    fn from(val: SshConnection) -> Self {
 91        SshConnectionOptions {
 92            host: val.host.into(),
 93            username: val.username,
 94            port: val.port,
 95            password: None,
 96            args: Some(val.args),
 97        }
 98    }
 99}
100
101#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
102pub struct SshProject {
103    pub paths: Vec<String>,
104}
105
106#[derive(Clone, Default, Serialize, Deserialize, JsonSchema)]
107pub struct RemoteSettingsContent {
108    pub ssh_connections: Option<Vec<SshConnection>>,
109}
110
111impl Settings for SshSettings {
112    const KEY: Option<&'static str> = None;
113
114    type FileContent = RemoteSettingsContent;
115
116    fn load(sources: SettingsSources<Self::FileContent>, _: &mut AppContext) -> Result<Self> {
117        sources.json_merge()
118    }
119}
120
121pub struct SshPrompt {
122    connection_string: SharedString,
123    nickname: Option<SharedString>,
124    status_message: Option<SharedString>,
125    prompt: Option<(View<Markdown>, oneshot::Sender<Result<String>>)>,
126    editor: View<Editor>,
127}
128
129pub struct SshConnectionModal {
130    pub(crate) prompt: View<SshPrompt>,
131    finished: bool,
132}
133
134impl SshPrompt {
135    pub(crate) fn new(
136        connection_options: &SshConnectionOptions,
137        nickname: Option<SharedString>,
138        cx: &mut ViewContext<Self>,
139    ) -> Self {
140        let connection_string = connection_options.connection_string().into();
141        Self {
142            connection_string,
143            nickname,
144            status_message: None,
145            prompt: None,
146            editor: cx.new_view(Editor::single_line),
147        }
148    }
149
150    pub fn set_prompt(
151        &mut self,
152        prompt: String,
153        tx: oneshot::Sender<Result<String>>,
154        cx: &mut ViewContext<Self>,
155    ) {
156        let theme = ThemeSettings::get_global(cx);
157
158        let mut text_style = cx.text_style();
159        let refinement = TextStyleRefinement {
160            font_family: Some(theme.buffer_font.family.clone()),
161            font_size: Some(theme.buffer_font_size.into()),
162            color: Some(cx.theme().colors().editor_foreground),
163            background_color: Some(gpui::transparent_black()),
164            ..Default::default()
165        };
166
167        text_style.refine(&refinement);
168        self.editor.update(cx, |editor, cx| {
169            if prompt.contains("yes/no") {
170                editor.set_masked(false, cx);
171            } else {
172                editor.set_masked(true, cx);
173            }
174            editor.set_text_style_refinement(refinement);
175            editor.set_cursor_shape(CursorShape::Block, cx);
176        });
177        let markdown_style = MarkdownStyle {
178            base_text_style: text_style,
179            selection_background_color: cx.theme().players().local().selection,
180            ..Default::default()
181        };
182        let markdown = cx.new_view(|cx| Markdown::new_text(prompt, markdown_style, None, cx, None));
183        self.prompt = Some((markdown, tx));
184        self.status_message.take();
185        cx.focus_view(&self.editor);
186        cx.notify();
187    }
188
189    pub fn set_status(&mut self, status: Option<String>, cx: &mut ViewContext<Self>) {
190        self.status_message = status.map(|s| s.into());
191        cx.notify();
192    }
193
194    pub fn confirm(&mut self, cx: &mut ViewContext<Self>) {
195        if let Some((_, tx)) = self.prompt.take() {
196            self.status_message = Some("Connecting".into());
197            self.editor.update(cx, |editor, cx| {
198                tx.send(Ok(editor.text(cx))).ok();
199                editor.clear(cx);
200            });
201        }
202    }
203}
204
205impl Render for SshPrompt {
206    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
207        let cx = cx.window_context();
208        let theme = cx.theme();
209        v_flex()
210            .key_context("PasswordPrompt")
211            .size_full()
212            .when_some(self.status_message.clone(), |el, status_message| {
213                el.child(
214                    h_flex()
215                        .p_2()
216                        .flex()
217                        .child(
218                            Icon::new(IconName::ArrowCircle)
219                                .size(IconSize::Medium)
220                                .with_animation(
221                                    "arrow-circle",
222                                    Animation::new(Duration::from_secs(2)).repeat(),
223                                    |icon, delta| {
224                                        icon.transform(Transformation::rotate(percentage(delta)))
225                                    },
226                                ),
227                        )
228                        .child(div().ml_1().text_ellipsis().overflow_x_hidden().child(
229                            Label::new(format!("{}", status_message)).size(LabelSize::Small),
230                        )),
231                )
232            })
233            .when_some(self.prompt.as_ref(), |el, prompt| {
234                el.child(
235                    div()
236                        .size_full()
237                        .overflow_hidden()
238                        .p_4()
239                        .border_t_1()
240                        .border_color(theme.colors().border_variant)
241                        .font_buffer(cx)
242                        .text_buffer(cx)
243                        .child(prompt.0.clone())
244                        .child(self.editor.clone()),
245                )
246            })
247    }
248}
249
250impl SshConnectionModal {
251    pub(crate) fn new(
252        connection_options: &SshConnectionOptions,
253        nickname: Option<SharedString>,
254        cx: &mut ViewContext<Self>,
255    ) -> Self {
256        Self {
257            prompt: cx.new_view(|cx| SshPrompt::new(connection_options, nickname, cx)),
258            finished: false,
259        }
260    }
261
262    fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
263        self.prompt.update(cx, |prompt, cx| prompt.confirm(cx))
264    }
265
266    pub fn finished(&mut self, cx: &mut ViewContext<Self>) {
267        self.finished = true;
268        cx.emit(DismissEvent);
269    }
270
271    fn dismiss(&mut self, _: &menu::Cancel, cx: &mut ViewContext<Self>) {
272        cx.emit(DismissEvent);
273    }
274}
275
276pub(crate) struct SshConnectionHeader {
277    pub(crate) connection_string: SharedString,
278    pub(crate) nickname: Option<SharedString>,
279}
280
281impl RenderOnce for SshConnectionHeader {
282    fn render(self, cx: &mut WindowContext) -> impl IntoElement {
283        let theme = cx.theme();
284
285        let mut header_color = theme.colors().text;
286        header_color.fade_out(0.96);
287
288        let (main_label, meta_label) = if let Some(nickname) = self.nickname {
289            (nickname, Some(format!("({})", self.connection_string)))
290        } else {
291            (self.connection_string, None)
292        };
293
294        h_flex()
295            .p_1()
296            .rounded_t_md()
297            .w_full()
298            .gap_2()
299            .justify_center()
300            .border_b_1()
301            .border_color(theme.colors().border_variant)
302            .bg(header_color)
303            .child(Icon::new(IconName::Server).size(IconSize::XSmall))
304            .child(
305                h_flex()
306                    .gap_1()
307                    .child(
308                        Label::new(main_label)
309                            .size(ui::LabelSize::Small)
310                            .single_line(),
311                    )
312                    .children(meta_label.map(|label| {
313                        Label::new(label)
314                            .size(ui::LabelSize::Small)
315                            .single_line()
316                            .color(Color::Muted)
317                    })),
318            )
319    }
320}
321
322impl Render for SshConnectionModal {
323    fn render(&mut self, cx: &mut ui::ViewContext<Self>) -> impl ui::IntoElement {
324        let nickname = self.prompt.read(cx).nickname.clone();
325        let connection_string = self.prompt.read(cx).connection_string.clone();
326        let theme = cx.theme();
327
328        let body_color = theme.colors().editor_background;
329
330        v_flex()
331            .elevation_3(cx)
332            .track_focus(&self.focus_handle(cx))
333            .on_action(cx.listener(Self::dismiss))
334            .on_action(cx.listener(Self::confirm))
335            .w(px(500.))
336            .border_1()
337            .border_color(theme.colors().border)
338            .child(
339                SshConnectionHeader {
340                    connection_string,
341                    nickname,
342                }
343                .render(cx),
344            )
345            .child(
346                h_flex()
347                    .rounded_b_md()
348                    .bg(body_color)
349                    .w_full()
350                    .child(self.prompt.clone()),
351            )
352    }
353}
354
355impl FocusableView for SshConnectionModal {
356    fn focus_handle(&self, cx: &gpui::AppContext) -> gpui::FocusHandle {
357        self.prompt.read(cx).editor.focus_handle(cx)
358    }
359}
360
361impl EventEmitter<DismissEvent> for SshConnectionModal {}
362
363impl ModalView for SshConnectionModal {
364    fn on_before_dismiss(&mut self, _: &mut ViewContext<Self>) -> workspace::DismissDecision {
365        return workspace::DismissDecision::Dismiss(self.finished);
366    }
367
368    fn fade_out_background(&self) -> bool {
369        true
370    }
371}
372
373#[derive(Clone)]
374pub struct SshClientDelegate {
375    window: AnyWindowHandle,
376    ui: View<SshPrompt>,
377    known_password: Option<String>,
378}
379
380impl remote::SshClientDelegate for SshClientDelegate {
381    fn ask_password(
382        &self,
383        prompt: String,
384        cx: &mut AsyncAppContext,
385    ) -> oneshot::Receiver<Result<String>> {
386        let (tx, rx) = oneshot::channel();
387        let mut known_password = self.known_password.clone();
388        if let Some(password) = known_password.take() {
389            tx.send(Ok(password)).ok();
390        } else {
391            self.window
392                .update(cx, |_, cx| {
393                    self.ui.update(cx, |modal, cx| {
394                        modal.set_prompt(prompt, tx, cx);
395                    })
396                })
397                .ok();
398        }
399        rx
400    }
401
402    fn set_status(&self, status: Option<&str>, cx: &mut AsyncAppContext) {
403        self.update_status(status, cx)
404    }
405
406    fn get_server_binary(
407        &self,
408        platform: SshPlatform,
409        cx: &mut AsyncAppContext,
410    ) -> oneshot::Receiver<Result<(PathBuf, SemanticVersion)>> {
411        let (tx, rx) = oneshot::channel();
412        let this = self.clone();
413        cx.spawn(|mut cx| async move {
414            tx.send(this.get_server_binary_impl(platform, &mut cx).await)
415                .ok();
416        })
417        .detach();
418        rx
419    }
420
421    fn remote_server_binary_path(
422        &self,
423        platform: SshPlatform,
424        cx: &mut AsyncAppContext,
425    ) -> Result<PathBuf> {
426        let release_channel = cx.update(|cx| ReleaseChannel::global(cx))?;
427        Ok(paths::remote_server_dir_relative().join(format!(
428            "zed-remote-server-{}-{}-{}",
429            release_channel.dev_name(),
430            platform.os,
431            platform.arch
432        )))
433    }
434}
435
436impl SshClientDelegate {
437    fn update_status(&self, status: Option<&str>, cx: &mut AsyncAppContext) {
438        self.window
439            .update(cx, |_, cx| {
440                self.ui.update(cx, |modal, cx| {
441                    modal.set_status(status.map(|s| s.to_string()), cx);
442                })
443            })
444            .ok();
445    }
446
447    async fn get_server_binary_impl(
448        &self,
449        platform: SshPlatform,
450        cx: &mut AsyncAppContext,
451    ) -> Result<(PathBuf, SemanticVersion)> {
452        let (version, release_channel) = cx.update(|cx| {
453            let global = AppVersion::global(cx);
454            (global, ReleaseChannel::global(cx))
455        })?;
456
457        // In dev mode, build the remote server binary from source
458        #[cfg(debug_assertions)]
459        if release_channel == ReleaseChannel::Dev {
460            let result = self.build_local(cx, platform, version).await?;
461            // Fall through to a remote binary if we're not able to compile a local binary
462            if let Some(result) = result {
463                return Ok(result);
464            }
465        }
466
467        self.update_status(Some("checking for latest version of remote server"), cx);
468        let binary_path = AutoUpdater::get_latest_remote_server_release(
469            platform.os,
470            platform.arch,
471            release_channel,
472            cx,
473        )
474        .await
475        .map_err(|e| {
476            anyhow::anyhow!(
477                "failed to download remote server binary (os: {}, arch: {}): {}",
478                platform.os,
479                platform.arch,
480                e
481            )
482        })?;
483
484        Ok((binary_path, version))
485    }
486
487    #[cfg(debug_assertions)]
488    async fn build_local(
489        &self,
490        cx: &mut AsyncAppContext,
491        platform: SshPlatform,
492        version: SemanticVersion,
493    ) -> Result<Option<(PathBuf, SemanticVersion)>> {
494        use smol::process::{Command, Stdio};
495
496        async fn run_cmd(command: &mut Command) -> Result<()> {
497            let output = command.stderr(Stdio::inherit()).output().await?;
498            if !output.status.success() {
499                Err(anyhow::anyhow!("failed to run command: {:?}", command))?;
500            }
501            Ok(())
502        }
503
504        if platform.arch == std::env::consts::ARCH && platform.os == std::env::consts::OS {
505            self.update_status(Some("Building remote server binary from source"), cx);
506            log::info!("building remote server binary from source");
507            run_cmd(Command::new("cargo").args([
508                "build",
509                "--package",
510                "remote_server",
511                "--features",
512                "debug-embed",
513                "--target-dir",
514                "target/remote_server",
515            ]))
516            .await?;
517
518            self.update_status(Some("Compressing binary"), cx);
519
520            run_cmd(Command::new("gzip").args([
521                "-9",
522                "-f",
523                "target/remote_server/debug/remote_server",
524            ]))
525            .await?;
526
527            let path = std::env::current_dir()?.join("target/remote_server/debug/remote_server.gz");
528            return Ok(Some((path, version)));
529        } else if let Some(triple) = platform.triple() {
530            smol::fs::create_dir_all("target/remote_server").await?;
531
532            self.update_status(Some("Installing cross.rs for cross-compilation"), cx);
533            log::info!("installing cross");
534            run_cmd(Command::new("cargo").args([
535                "install",
536                "cross",
537                "--git",
538                "https://github.com/cross-rs/cross",
539            ]))
540            .await?;
541
542            self.update_status(
543                Some(&format!(
544                    "Building remote server binary from source for {}",
545                    &triple
546                )),
547                cx,
548            );
549            log::info!("building remote server binary from source for {}", &triple);
550            run_cmd(
551                Command::new("cross")
552                    .args([
553                        "build",
554                        "--package",
555                        "remote_server",
556                        "--features",
557                        "debug-embed",
558                        "--target-dir",
559                        "target/remote_server",
560                        "--target",
561                        &triple,
562                    ])
563                    .env(
564                        "CROSS_CONTAINER_OPTS",
565                        "--mount type=bind,src=./target,dst=/app/target",
566                    ),
567            )
568            .await?;
569
570            self.update_status(Some("Compressing binary"), cx);
571
572            run_cmd(Command::new("gzip").args([
573                "-9",
574                "-f",
575                &format!("target/remote_server/{}/debug/remote_server", triple),
576            ]))
577            .await?;
578
579            let path = std::env::current_dir()?.join(format!(
580                "target/remote_server/{}/debug/remote_server.gz",
581                triple
582            ));
583
584            return Ok(Some((path, version)));
585        } else {
586            return Ok(None);
587        }
588    }
589}
590
591pub fn connect_over_ssh(
592    unique_identifier: String,
593    connection_options: SshConnectionOptions,
594    ui: View<SshPrompt>,
595    cx: &mut WindowContext,
596) -> Task<Result<Model<SshRemoteClient>>> {
597    let window = cx.window_handle();
598    let known_password = connection_options.password.clone();
599
600    remote::SshRemoteClient::new(
601        unique_identifier,
602        connection_options,
603        Arc::new(SshClientDelegate {
604            window,
605            ui,
606            known_password,
607        }),
608        cx,
609    )
610}
611
612pub async fn open_ssh_project(
613    connection_options: SshConnectionOptions,
614    paths: Vec<PathBuf>,
615    app_state: Arc<AppState>,
616    open_options: workspace::OpenOptions,
617    nickname: Option<SharedString>,
618    cx: &mut AsyncAppContext,
619) -> Result<()> {
620    let window = if let Some(window) = open_options.replace_window {
621        window
622    } else {
623        let options = cx.update(|cx| (app_state.build_window_options)(None, cx))?;
624        cx.open_window(options, |cx| {
625            let project = project::Project::local(
626                app_state.client.clone(),
627                app_state.node_runtime.clone(),
628                app_state.user_store.clone(),
629                app_state.languages.clone(),
630                app_state.fs.clone(),
631                None,
632                cx,
633            );
634            cx.new_view(|cx| Workspace::new(None, project, app_state.clone(), cx))
635        })?
636    };
637
638    loop {
639        let delegate = window.update(cx, {
640            let connection_options = connection_options.clone();
641            let nickname = nickname.clone();
642            move |workspace, cx| {
643                cx.activate_window();
644                workspace.toggle_modal(cx, |cx| {
645                    SshConnectionModal::new(&connection_options, nickname.clone(), cx)
646                });
647                let ui = workspace
648                    .active_modal::<SshConnectionModal>(cx)
649                    .unwrap()
650                    .read(cx)
651                    .prompt
652                    .clone();
653
654                Arc::new(SshClientDelegate {
655                    window: cx.window_handle(),
656                    ui,
657                    known_password: connection_options.password.clone(),
658                })
659            }
660        })?;
661
662        let did_open_ssh_project = cx
663            .update(|cx| {
664                workspace::open_ssh_project(
665                    window,
666                    connection_options.clone(),
667                    delegate.clone(),
668                    app_state.clone(),
669                    paths.clone(),
670                    cx,
671                )
672            })?
673            .await;
674
675        window
676            .update(cx, |workspace, cx| {
677                if let Some(ui) = workspace.active_modal::<SshConnectionModal>(cx) {
678                    ui.update(cx, |modal, cx| modal.finished(cx))
679                }
680            })
681            .ok();
682
683        if let Err(e) = did_open_ssh_project {
684            log::error!("Failed to open project: {:?}", e);
685            let response = window
686                .update(cx, |_, cx| {
687                    cx.prompt(
688                        PromptLevel::Critical,
689                        "Failed to connect over SSH",
690                        Some(&e.to_string()),
691                        &["Retry", "Ok"],
692                    )
693                })?
694                .await;
695
696            if response == Ok(0) {
697                continue;
698            }
699        }
700
701        break;
702    }
703
704    // Already showed the error to the user
705    Ok(())
706}