copilot_migration.rs

  1use credentials_provider::CredentialsProvider;
  2use gpui::App;
  3use std::path::PathBuf;
  4
  5const COPILOT_CHAT_EXTENSION_ID: &str = "copilot_chat";
  6const COPILOT_CHAT_PROVIDER_ID: &str = "copilot_chat";
  7
  8pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
  9    if extension_id != COPILOT_CHAT_EXTENSION_ID {
 10        return;
 11    }
 12
 13    let credential_key = format!(
 14        "extension-llm-{}:{}",
 15        COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
 16    );
 17
 18    let credentials_provider = <dyn CredentialsProvider>::global(cx);
 19
 20    cx.spawn(async move |cx| {
 21        let existing_credential = credentials_provider
 22            .read_credentials(&credential_key, &cx)
 23            .await
 24            .ok()
 25            .flatten();
 26
 27        if existing_credential.is_some() {
 28            log::debug!("Copilot Chat extension already has credentials, skipping migration");
 29            return;
 30        }
 31
 32        let oauth_token = match read_copilot_oauth_token().await {
 33            Some(token) => token,
 34            None => {
 35                log::debug!("No existing Copilot OAuth token found to migrate");
 36                return;
 37            }
 38        };
 39
 40        log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
 41
 42        match credentials_provider
 43            .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx)
 44            .await
 45        {
 46            Ok(()) => {
 47                log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
 48            }
 49            Err(err) => {
 50                log::error!("Failed to migrate Copilot OAuth token: {}", err);
 51            }
 52        }
 53    })
 54    .detach();
 55}
 56
 57async fn read_copilot_oauth_token() -> Option<String> {
 58    let config_paths = copilot_config_paths();
 59
 60    for path in config_paths {
 61        if let Some(token) = read_oauth_token_from_file(&path).await {
 62            return Some(token);
 63        }
 64    }
 65
 66    None
 67}
 68
 69fn copilot_config_paths() -> Vec<PathBuf> {
 70    let config_dir = if cfg!(target_os = "windows") {
 71        dirs::data_local_dir()
 72    } else {
 73        std::env::var("XDG_CONFIG_HOME")
 74            .map(PathBuf::from)
 75            .ok()
 76            .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
 77    };
 78
 79    let Some(config_dir) = config_dir else {
 80        return Vec::new();
 81    };
 82
 83    let copilot_dir = config_dir.join("github-copilot");
 84
 85    vec![
 86        copilot_dir.join("hosts.json"),
 87        copilot_dir.join("apps.json"),
 88    ]
 89}
 90
 91async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
 92    let contents = match smol::fs::read_to_string(path).await {
 93        Ok(contents) => contents,
 94        Err(_) => return None,
 95    };
 96
 97    extract_oauth_token(&contents, "github.com")
 98}
 99
100fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
101    let value: serde_json::Value = serde_json::from_str(contents).ok()?;
102    let obj = value.as_object()?;
103
104    for (key, value) in obj.iter() {
105        if key.starts_with(domain) {
106            if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
107                return Some(token.to_string());
108            }
109        }
110    }
111
112    None
113}
114
115#[cfg(test)]
116mod tests {
117    use super::*;
118
119    #[test]
120    fn test_extract_oauth_token() {
121        let contents = r#"{
122            "github.com": {
123                "oauth_token": "ghu_test_token_12345"
124            }
125        }"#;
126
127        let token = extract_oauth_token(contents, "github.com");
128        assert_eq!(token, Some("ghu_test_token_12345".to_string()));
129    }
130
131    #[test]
132    fn test_extract_oauth_token_with_prefix() {
133        let contents = r#"{
134            "github.com:user": {
135                "oauth_token": "ghu_another_token"
136            }
137        }"#;
138
139        let token = extract_oauth_token(contents, "github.com");
140        assert_eq!(token, Some("ghu_another_token".to_string()));
141    }
142
143    #[test]
144    fn test_extract_oauth_token_missing() {
145        let contents = r#"{
146            "gitlab.com": {
147                "oauth_token": "some_token"
148            }
149        }"#;
150
151        let token = extract_oauth_token(contents, "github.com");
152        assert_eq!(token, None);
153    }
154
155    #[test]
156    fn test_extract_oauth_token_invalid_json() {
157        let contents = "not valid json";
158        let token = extract_oauth_token(contents, "github.com");
159        assert_eq!(token, None);
160    }
161}