copilot_migration.rs

  1use credentials_provider::CredentialsProvider;
  2use gpui::App;
  3use std::path::PathBuf;
  4
  5const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat";
  6const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat";
  7
  8pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
  9    if extension_id != COPILOT_CHAT_EXTENSION_ID {
 10        return;
 11    }
 12
 13    let credential_key = format!(
 14        "extension-llm-{}:{}",
 15        COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
 16    );
 17
 18    let credentials_provider = <dyn CredentialsProvider>::global(cx);
 19
 20    cx.spawn(async move |cx| {
 21        let existing_credential = credentials_provider
 22            .read_credentials(&credential_key, &cx)
 23            .await
 24            .ok()
 25            .flatten();
 26
 27        if existing_credential.is_some() {
 28            log::debug!("Copilot Chat extension already has credentials, skipping migration");
 29            return;
 30        }
 31
 32        let oauth_token = match read_copilot_oauth_token().await {
 33            Some(token) => token,
 34            None => {
 35                log::debug!("No existing Copilot OAuth token found to migrate");
 36                return;
 37            }
 38        };
 39
 40        log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
 41
 42        match credentials_provider
 43            .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx)
 44            .await
 45        {
 46            Ok(()) => {
 47                log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
 48            }
 49            Err(err) => {
 50                log::error!("Failed to migrate Copilot OAuth token: {}", err);
 51            }
 52        }
 53    })
 54    .detach();
 55}
 56
 57async fn read_copilot_oauth_token() -> Option<String> {
 58    let config_paths = copilot_config_paths();
 59
 60    for path in config_paths {
 61        if let Some(token) = read_oauth_token_from_file(&path).await {
 62            return Some(token);
 63        }
 64    }
 65
 66    None
 67}
 68
 69fn copilot_config_paths() -> Vec<PathBuf> {
 70    let config_dir = if cfg!(target_os = "windows") {
 71        dirs::data_local_dir()
 72    } else {
 73        std::env::var("XDG_CONFIG_HOME")
 74            .map(PathBuf::from)
 75            .ok()
 76            .or_else(|| dirs::home_dir().map(|h| h.join(".config")))
 77    };
 78
 79    let Some(config_dir) = config_dir else {
 80        return Vec::new();
 81    };
 82
 83    let copilot_dir = config_dir.join("github-copilot");
 84
 85    vec![
 86        copilot_dir.join("hosts.json"),
 87        copilot_dir.join("apps.json"),
 88    ]
 89}
 90
 91async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
 92    let contents = match smol::fs::read_to_string(path).await {
 93        Ok(contents) => contents,
 94        Err(_) => return None,
 95    };
 96
 97    extract_oauth_token(&contents, "github.com")
 98}
 99
100fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
101    let value: serde_json::Value = serde_json::from_str(contents).ok()?;
102    let obj = value.as_object()?;
103
104    for (key, value) in obj.iter() {
105        if key.starts_with(domain) {
106            if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
107                return Some(token.to_string());
108            }
109        }
110    }
111
112    None
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Credential key the migration reads and writes for the Copilot Chat extension.
    const CREDENTIAL_KEY: &str = "extension-llm-copilot-chat:copilot-chat";

    /// Runs the extractor against `contents` for the `github.com` domain.
    fn github_token(contents: &str) -> Option<String> {
        extract_oauth_token(contents, "github.com")
    }

    #[test]
    fn test_extract_oauth_token_from_hosts_json() {
        let json = r#"{
            "github.com": {
                "oauth_token": "ghu_test_token_12345"
            }
        }"#;

        assert_eq!(github_token(json).as_deref(), Some("ghu_test_token_12345"));
    }

    #[test]
    fn test_extract_oauth_token_with_user_suffix() {
        let json = r#"{
            "github.com:user": {
                "oauth_token": "ghu_another_token"
            }
        }"#;

        assert_eq!(github_token(json).as_deref(), Some("ghu_another_token"));
    }

    #[test]
    fn test_extract_oauth_token_wrong_domain() {
        let json = r#"{
            "gitlab.com": {
                "oauth_token": "some_token"
            }
        }"#;

        assert!(github_token(json).is_none());
    }

    #[test]
    fn test_extract_oauth_token_invalid_json() {
        assert!(github_token("not valid json").is_none());
    }

    #[test]
    fn test_extract_oauth_token_missing_oauth_token_field() {
        let json = r#"{
            "github.com": {
                "user": "testuser"
            }
        }"#;

        assert!(github_token(json).is_none());
    }

    #[test]
    fn test_extract_oauth_token_multiple_entries_picks_first_match() {
        let json = r#"{
            "gitlab.com": {
                "oauth_token": "gitlab_token"
            },
            "github.com": {
                "oauth_token": "github_token"
            }
        }"#;

        assert_eq!(github_token(json).as_deref(), Some("github_token"));
    }

    #[gpui::test]
    async fn test_skips_migration_if_extension_already_has_credentials(cx: &mut TestAppContext) {
        let existing_token = "existing_oauth_token";
        cx.write_credentials(CREDENTIAL_KEY, "api_key", existing_token.as_bytes());

        cx.update(|cx| migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx));
        cx.run_until_parked();

        let (_username, password) = cx
            .read_credentials(CREDENTIAL_KEY)
            .expect("pre-existing credentials should still be stored");
        assert_eq!(
            String::from_utf8(password).unwrap(),
            existing_token,
            "Should not overwrite existing credentials"
        );
    }

    #[gpui::test]
    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
        cx.update(|cx| migrate_copilot_credentials_if_needed("some-other-extension", cx));
        cx.run_until_parked();

        assert!(
            cx.read_credentials(CREDENTIAL_KEY).is_none(),
            "Should not create credentials for other extensions"
        );
    }

    // NOTE(review): the migration reads the real Copilot config files, located via
    // `XDG_CONFIG_HOME` / the home dir. On a machine with a Copilot
    // `hosts.json`/`apps.json` this test may not be hermetic. Consider injecting
    // the config paths.
    #[gpui::test]
    async fn test_no_migration_when_no_copilot_config_exists(cx: &mut TestAppContext) {
        cx.update(|cx| migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx));
        cx.run_until_parked();

        assert!(
            cx.read_credentials(CREDENTIAL_KEY).is_none(),
            "Should not create credentials when no copilot config exists"
        );
    }
}