Allow updater to check for updates after downloading one (#31066)

Joseph T. Lyons created

This PR brings back https://github.com/zed-industries/zed/pull/30969 and
adds some initial testing.

https://github.com/zed-industries/zed/pull/30969 did indeed allow Zed to
continue doing downloads after downloading one, but it introduced a bug
where Zed would download a new binary every time it polled, even if the
version was the same as the running instance.

This code could use a refactor to allow more / better testing, but this
is a start.

Release Notes:

- N/A

Change summary

crates/activity_indicator/src/activity_indicator.rs |   2 
crates/auto_update/Cargo.toml                       |   2 
crates/auto_update/src/auto_update.rs               | 427 +++++++++++++-
3 files changed, 380 insertions(+), 51 deletions(-)

Detailed changes

crates/activity_indicator/src/activity_indicator.rs 🔗

@@ -485,7 +485,7 @@ impl ActivityIndicator {
                         this.dismiss_error_message(&DismissErrorMessage, window, cx)
                     })),
                 }),
-                AutoUpdateStatus::Updated { binary_path } => Some(Content {
+                AutoUpdateStatus::Updated { binary_path, .. } => Some(Content {
                     icon: None,
                     message: "Click to restart and update Zed".to_string(),
                     on_click: Some(Arc::new({

crates/auto_update/Cargo.toml 🔗

@@ -16,7 +16,7 @@ doctest = false
 anyhow.workspace = true
 client.workspace = true
 db.workspace = true
-gpui.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
 http_client.workspace = true
 log.workspace = true
 paths.workspace = true

crates/auto_update/src/auto_update.rs 🔗

@@ -39,13 +39,22 @@ struct UpdateRequestBody {
     destination: &'static str,
 }
 
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum VersionCheckType {
+    Sha(String),
+    Semantic(SemanticVersion),
+}
+
 #[derive(Clone, PartialEq, Eq)]
 pub enum AutoUpdateStatus {
     Idle,
     Checking,
     Downloading,
     Installing,
-    Updated { binary_path: PathBuf },
+    Updated {
+        binary_path: PathBuf,
+        version: VersionCheckType,
+    },
     Errored,
 }
 
@@ -62,7 +71,7 @@ pub struct AutoUpdater {
     pending_poll: Option<Task<Option<()>>>,
 }
 
-#[derive(Deserialize, Debug)]
+#[derive(Deserialize, Clone, Debug)]
 pub struct JsonRelease {
     pub version: String,
     pub url: String,
@@ -307,7 +316,7 @@ impl AutoUpdater {
     }
 
     pub fn poll(&mut self, cx: &mut Context<Self>) {
-        if self.pending_poll.is_some() || self.status.is_updated() {
+        if self.pending_poll.is_some() {
             return;
         }
 
@@ -483,35 +492,38 @@ impl AutoUpdater {
     }
 
     async fn update(this: Entity<Self>, mut cx: AsyncApp) -> Result<()> {
-        let (client, current_version, release_channel) = this.update(&mut cx, |this, cx| {
-            this.status = AutoUpdateStatus::Checking;
-            cx.notify();
-            (
-                this.http_client.clone(),
-                this.current_version,
-                ReleaseChannel::try_global(cx),
-            )
-        })?;
-
-        let release =
-            Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
-
-        let should_download = match *RELEASE_CHANNEL {
-            ReleaseChannel::Nightly => cx
-                .update(|cx| AppCommitSha::try_global(cx).map(|sha| release.version != sha.0))
-                .ok()
-                .flatten()
-                .unwrap_or(true),
-            _ => release.version.parse::<SemanticVersion>()? > current_version,
-        };
-
-        if !should_download {
+        let (client, installed_version, status, release_channel) =
             this.update(&mut cx, |this, cx| {
-                this.status = AutoUpdateStatus::Idle;
+                this.status = AutoUpdateStatus::Checking;
                 cx.notify();
+                (
+                    this.http_client.clone(),
+                    this.current_version,
+                    this.status.clone(),
+                    ReleaseChannel::try_global(cx),
+                )
             })?;
-            return Ok(());
-        }
+
+        let fetched_release_data =
+            Self::get_latest_release(&this, "zed", OS, ARCH, release_channel, &mut cx).await?;
+        let fetched_version = fetched_release_data.clone().version;
+        let app_commit_sha = cx.update(|cx| AppCommitSha::try_global(cx).map(|sha| sha.0));
+        let newer_version = Self::check_for_newer_version(
+            *RELEASE_CHANNEL,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_version,
+        )?;
+
+        let Some(newer_version) = newer_version else {
+            return this.update(&mut cx, |this, cx| {
+                if !matches!(this.status, AutoUpdateStatus::Updated { .. }) {
+                    this.status = AutoUpdateStatus::Idle;
+                    cx.notify();
+                }
+            });
+        };
 
         this.update(&mut cx, |this, cx| {
             this.status = AutoUpdateStatus::Downloading;
@@ -519,6 +531,71 @@ impl AutoUpdater {
         })?;
 
         let installer_dir = InstallerDir::new().await?;
+        let target_path = Self::target_path(&installer_dir).await?;
+        download_release(&target_path, fetched_release_data, client, &cx).await?;
+
+        this.update(&mut cx, |this, cx| {
+            this.status = AutoUpdateStatus::Installing;
+            cx.notify();
+        })?;
+
+        let binary_path = Self::binary_path(installer_dir, target_path, &cx).await?;
+
+        this.update(&mut cx, |this, cx| {
+            this.set_should_show_update_notification(true, cx)
+                .detach_and_log_err(cx);
+            this.status = AutoUpdateStatus::Updated {
+                binary_path,
+                version: newer_version,
+            };
+            cx.notify();
+        })
+    }
+
+    fn check_for_newer_version(
+        release_channel: ReleaseChannel,
+        app_commit_sha: Result<Option<String>>,
+        installed_version: SemanticVersion,
+        status: AutoUpdateStatus,
+        fetched_version: String,
+    ) -> Result<Option<VersionCheckType>> {
+        let parsed_fetched_version = fetched_version.parse::<SemanticVersion>();
+
+        if let AutoUpdateStatus::Updated { version, .. } = status {
+            match version {
+                VersionCheckType::Sha(cached_version) => {
+                    let should_download = fetched_version != cached_version;
+                    let newer_version =
+                        should_download.then(|| VersionCheckType::Sha(fetched_version));
+                    return Ok(newer_version);
+                }
+                VersionCheckType::Semantic(cached_version) => {
+                    return Self::check_for_newer_version_non_nightly(
+                        cached_version,
+                        parsed_fetched_version?,
+                    );
+                }
+            }
+        }
+
+        match release_channel {
+            ReleaseChannel::Nightly => {
+                let should_download = app_commit_sha
+                    .ok()
+                    .flatten()
+                    .map(|sha| fetched_version != sha)
+                    .unwrap_or(true);
+                let newer_version = should_download.then(|| VersionCheckType::Sha(fetched_version));
+                Ok(newer_version)
+            }
+            _ => Self::check_for_newer_version_non_nightly(
+                installed_version,
+                parsed_fetched_version?,
+            ),
+        }
+    }
+
+    async fn target_path(installer_dir: &InstallerDir) -> Result<PathBuf> {
         let filename = match OS {
             "macos" => anyhow::Ok("Zed.dmg"),
             "linux" => Ok("zed.tar.gz"),
@@ -532,29 +609,29 @@ impl AutoUpdater {
             "Aborting. Could not find rsync which is required for auto-updates."
         );
 
-        let downloaded_asset = installer_dir.path().join(filename);
-        download_release(&downloaded_asset, release, client, &cx).await?;
-
-        this.update(&mut cx, |this, cx| {
-            this.status = AutoUpdateStatus::Installing;
-            cx.notify();
-        })?;
+        Ok(installer_dir.path().join(filename))
+    }
 
-        let binary_path = match OS {
-            "macos" => install_release_macos(&installer_dir, downloaded_asset, &cx).await,
-            "linux" => install_release_linux(&installer_dir, downloaded_asset, &cx).await,
-            "windows" => install_release_windows(downloaded_asset).await,
+    async fn binary_path(
+        installer_dir: InstallerDir,
+        target_path: PathBuf,
+        cx: &AsyncApp,
+    ) -> Result<PathBuf> {
+        match OS {
+            "macos" => install_release_macos(&installer_dir, target_path, cx).await,
+            "linux" => install_release_linux(&installer_dir, target_path, cx).await,
+            "windows" => install_release_windows(target_path).await,
             unsupported_os => anyhow::bail!("not supported: {unsupported_os}"),
-        }?;
-
-        this.update(&mut cx, |this, cx| {
-            this.set_should_show_update_notification(true, cx)
-                .detach_and_log_err(cx);
-            this.status = AutoUpdateStatus::Updated { binary_path };
-            cx.notify();
-        })?;
+        }
+    }
 
-        Ok(())
+    fn check_for_newer_version_non_nightly(
+        installed_version: SemanticVersion,
+        fetched_version: SemanticVersion,
+    ) -> Result<Option<VersionCheckType>> {
+        let should_download = fetched_version > installed_version;
+        let newer_version = should_download.then(|| VersionCheckType::Semantic(fetched_version));
+        Ok(newer_version)
     }
 
     pub fn set_should_show_update_notification(
@@ -829,3 +906,255 @@ pub fn check_pending_installation() -> bool {
     }
     false
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_stable_does_not_update_when_fetched_version_is_not_higher() {
+        let release_channel = ReleaseChannel::Stable;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Idle;
+        let fetched_version = SemanticVersion::new(1, 0, 0);
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_version.to_string(),
+        );
+
+        assert_eq!(newer_version.unwrap(), None);
+    }
+
+    #[test]
+    fn test_stable_does_update_when_fetched_version_is_higher() {
+        let release_channel = ReleaseChannel::Stable;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Idle;
+        let fetched_version = SemanticVersion::new(1, 0, 1);
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_version.to_string(),
+        );
+
+        assert_eq!(
+            newer_version.unwrap(),
+            Some(VersionCheckType::Semantic(fetched_version))
+        );
+    }
+
+    #[test]
+    fn test_stable_does_not_update_when_fetched_version_is_not_higher_than_cached() {
+        let release_channel = ReleaseChannel::Stable;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Updated {
+            binary_path: PathBuf::new(),
+            version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
+        };
+        let fetched_version = SemanticVersion::new(1, 0, 1);
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_version.to_string(),
+        );
+
+        assert_eq!(newer_version.unwrap(), None);
+    }
+
+    #[test]
+    fn test_stable_does_update_when_fetched_version_is_higher_than_cached() {
+        let release_channel = ReleaseChannel::Stable;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Updated {
+            binary_path: PathBuf::new(),
+            version: VersionCheckType::Semantic(SemanticVersion::new(1, 0, 1)),
+        };
+        let fetched_version = SemanticVersion::new(1, 0, 2);
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_version.to_string(),
+        );
+
+        assert_eq!(
+            newer_version.unwrap(),
+            Some(VersionCheckType::Semantic(fetched_version))
+        );
+    }
+
+    #[test]
+    fn test_nightly_does_not_update_when_fetched_sha_is_same() {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Idle;
+        let fetched_sha = "a".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha,
+        );
+
+        assert_eq!(newer_version.unwrap(), None);
+    }
+
+    #[test]
+    fn test_nightly_does_update_when_fetched_sha_is_not_same() {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Idle;
+        let fetched_sha = "b".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha.clone(),
+        );
+
+        assert_eq!(
+            newer_version.unwrap(),
+            Some(VersionCheckType::Sha(fetched_sha))
+        );
+    }
+
+    #[test]
+    fn test_nightly_does_not_update_when_fetched_sha_is_same_as_cached() {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Updated {
+            binary_path: PathBuf::new(),
+            version: VersionCheckType::Sha("b".to_string()),
+        };
+        let fetched_sha = "b".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha,
+        );
+
+        assert_eq!(newer_version.unwrap(), None);
+    }
+
+    #[test]
+    fn test_nightly_does_update_when_fetched_sha_is_not_same_as_cached() {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(Some("a".to_string()));
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Updated {
+            binary_path: PathBuf::new(),
+            version: VersionCheckType::Sha("b".to_string()),
+        };
+        let fetched_sha = "c".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha.clone(),
+        );
+
+        assert_eq!(
+            newer_version.unwrap(),
+            Some(VersionCheckType::Sha(fetched_sha))
+        );
+    }
+
+    #[test]
+    fn test_nightly_does_update_when_installed_versions_sha_cannot_be_retrieved() {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(None);
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Idle;
+        let fetched_sha = "a".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha.clone(),
+        );
+
+        assert_eq!(
+            newer_version.unwrap(),
+            Some(VersionCheckType::Sha(fetched_sha))
+        );
+    }
+
+    #[test]
+    fn test_nightly_does_not_update_when_cached_update_is_same_as_fetched_and_installed_versions_sha_cannot_be_retrieved()
+     {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(None);
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Updated {
+            binary_path: PathBuf::new(),
+            version: VersionCheckType::Sha("b".to_string()),
+        };
+        let fetched_sha = "b".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha,
+        );
+
+        assert_eq!(newer_version.unwrap(), None);
+    }
+
+    #[test]
+    fn test_nightly_does_update_when_cached_update_is_not_same_as_fetched_and_installed_versions_sha_cannot_be_retrieved()
+     {
+        let release_channel = ReleaseChannel::Nightly;
+        let app_commit_sha = Ok(None);
+        let installed_version = SemanticVersion::new(1, 0, 0);
+        let status = AutoUpdateStatus::Updated {
+            binary_path: PathBuf::new(),
+            version: VersionCheckType::Sha("b".to_string()),
+        };
+        let fetched_sha = "c".to_string();
+
+        let newer_version = AutoUpdater::check_for_newer_version(
+            release_channel,
+            app_commit_sha,
+            installed_version,
+            status,
+            fetched_sha.clone(),
+        );
+
+        assert_eq!(
+            newer_version.unwrap(),
+            Some(VersionCheckType::Sha(fetched_sha))
+        );
+    }
+}