// download_file_capability.rs

  1use serde::{Deserialize, Serialize};
  2use url::Url;
  3
  4#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
  5#[serde(rename_all = "snake_case")]
  6pub struct DownloadFileCapability {
  7    pub host: String,
  8    pub path: Vec<String>,
  9}
 10
 11impl DownloadFileCapability {
 12    /// Returns whether the capability allows downloading a file from the given URL.
 13    pub fn allows(&self, url: &Url) -> bool {
 14        let Some(desired_host) = url.host_str() else {
 15            return false;
 16        };
 17
 18        let Some(desired_path) = url.path_segments() else {
 19            return false;
 20        };
 21        let desired_path = desired_path.collect::<Vec<_>>();
 22
 23        if self.host != desired_host && self.host != "*" {
 24            return false;
 25        }
 26
 27        for (ix, path_segment) in self.path.iter().enumerate() {
 28            if path_segment == "**" {
 29                return true;
 30            }
 31
 32            if ix >= desired_path.len() {
 33                return false;
 34            }
 35
 36            if path_segment != "*" && path_segment != desired_path[ix] {
 37                return false;
 38            }
 39        }
 40
 41        if self.path.len() < desired_path.len() {
 42            return false;
 43        }
 44
 45        true
 46    }
 47}
 48
 49#[cfg(test)]
 50mod tests {
 51    use pretty_assertions::assert_eq;
 52
 53    use super::*;
 54
 55    #[test]
 56    fn test_allows() {
 57        let capability = DownloadFileCapability {
 58            host: "*".to_string(),
 59            path: vec!["**".to_string()],
 60        };
 61        assert_eq!(
 62            capability.allows(&"https://example.com/some/path".parse().unwrap()),
 63            true
 64        );
 65
 66        let capability = DownloadFileCapability {
 67            host: "github.com".to_string(),
 68            path: vec!["**".to_string()],
 69        };
 70        assert_eq!(
 71            capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()),
 72            true
 73        );
 74        assert_eq!(
 75            capability.allows(
 76                &"https://fake-github.com/some-owner/some-repo"
 77                    .parse()
 78                    .unwrap()
 79            ),
 80            false
 81        );
 82
 83        let capability = DownloadFileCapability {
 84            host: "github.com".to_string(),
 85            path: vec!["specific-owner".to_string(), "*".to_string()],
 86        };
 87        assert_eq!(
 88            capability.allows(&"https://github.com/some-owner/some-repo".parse().unwrap()),
 89            false
 90        );
 91        assert_eq!(
 92            capability.allows(
 93                &"https://github.com/specific-owner/some-repo"
 94                    .parse()
 95                    .unwrap()
 96            ),
 97            true
 98        );
 99
100        let capability = DownloadFileCapability {
101            host: "github.com".to_string(),
102            path: vec!["specific-owner".to_string(), "*".to_string()],
103        };
104        assert_eq!(
105            capability.allows(
106                &"https://github.com/some-owner/some-repo/extra"
107                    .parse()
108                    .unwrap()
109            ),
110            false
111        );
112        assert_eq!(
113            capability.allows(
114                &"https://github.com/specific-owner/some-repo/extra"
115                    .parse()
116                    .unwrap()
117            ),
118            false
119        );
120    }
121}