use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;

use anyhow::{Result, anyhow};
use chrono::DateTime;
use collections::HashSet;
use fs::Fs;
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use gpui::{App, AsyncApp, Global, prelude::*};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use itertools::Itertools;
use paths::home_dir;
use serde::{Deserialize, Serialize};
use settings::watch_config_dir;

pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
pub const COPILOT_CHAT_MODELS_URL: &str = "https://api.githubcopilot.com/models";

// Copilot's base model, as defined by Microsoft in the premium requests table.
// It is moved to the front of the Copilot model list and used for
// 'fast' requests (e.g. title generation).
// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests
const DEFAULT_MODEL_ID: &str = "gpt-4.1";

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}

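/// Deserializes the `data` array one entry at a time, logging and skipping any
/// entry that fails to parse so a single malformed model does not discard the
/// whole list.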
fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
    let models = raw_values
        .into_iter()
        .filter_map(|value| match serde_json::from_value::<Model>(value) {
            Ok(model) => Some(model),
            Err(err) => {
                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
                None
            }
        })
        .collect();

    Ok(models)
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct Model {
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    model_picker_enabled: bool,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: usize,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
}

impl Model {
    pub fn uses_streaming(&self) -> bool {
        self.capabilities.supports.streaming
    }

    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    pub fn display_name(&self) -> &str {
        self.name.as_str()
    }

    pub fn max_token_count(&self) -> usize {
        self.capabilities.limits.max_prompt_tokens
    }

    pub fn supports_tools(&self) -> bool {
        self.capabilities.supports.tool_calls
    }

    pub fn vendor(&self) -> ModelVendor {
        self.vendor
    }
}

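/// Request body serialized and POSTed to the Copilot chat completions endpoint
/// ([`COPILOT_CHAT_COMPLETION_URL`]).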
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}

#[derive(Deserialize, Debug)]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub created: u64,
    pub id: String,
}

#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    expires_at: i64,
}

#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
}

impl ApiToken {
    pub fn remaining_seconds(&self) -> i64 {
        self.expires_at
            .timestamp()
            .saturating_sub(chrono::Utc::now().timestamp())
    }
}

impl TryFrom<ApiTokenResponse> for ApiToken {
    type Error = anyhow::Error;

    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
        let expires_at = DateTime::from_timestamp(response.expires_at, 0)
            .ok_or_else(|| anyhow!("invalid expires_at"))?;

        Ok(Self {
            api_key: response.token,
            expires_at,
        })
    }
}

struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

pub struct CopilotChat {
    oauth_token: Option<String>,
    api_token: Option<ApiToken>,
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}

pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}

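/// Directory where the GitHub Copilot integrations store their credential files:
/// `~/.config/github-copilot` on Unix-like systems, and `AppData\Local\github-copilot`
/// under the home directory on Windows.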
pub fn copilot_chat_config_dir() -> &'static PathBuf {
    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();

    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
        if cfg!(target_os = "windows") {
            home_dir().join("AppData").join("Local")
        } else {
            home_dir().join(".config")
        }
        .join("github-copilot")
    })
}

fn copilot_chat_config_paths() -> [PathBuf; 2] {
    let base_dir = copilot_chat_config_dir();
    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
}

impl CopilotChat {
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn({
            let client = client.clone();
            async move |cx| {
                let mut parent_watch_rx = watch_config_dir(
                    cx.background_executor(),
                    fs.clone(),
                    dir_path.clone(),
                    config_paths,
                );
                while let Some(contents) = parent_watch_rx.next().await {
                    let oauth_token = extract_oauth_token(contents);
                    cx.update(|cx| {
                        if let Some(this) = Self::global(cx).as_ref() {
                            this.update(cx, |this, cx| {
                                this.oauth_token = oauth_token.clone();
                                cx.notify();
                            });
                        }
                    })?;

                    if let Some(ref oauth_token) = oauth_token {
                        let api_token = request_api_token(oauth_token, client.clone()).await?;
                        cx.update(|cx| {
                            if let Some(this) = Self::global(cx).as_ref() {
                                this.update(cx, |this, cx| {
                                    this.api_token = Some(api_token.clone());
                                    cx.notify();
                                });
                            }
                        })?;
                        let models = get_models(api_token.api_key, client.clone()).await?;
                        cx.update(|cx| {
                            if let Some(this) = Self::global(cx).as_ref() {
                                this.update(cx, |this, cx| {
                                    this.models = Some(models);
                                    cx.notify();
                                });
                            }
                        })?;
                    }
                }
                anyhow::Ok(())
            }
        })
        .detach_and_log_err(cx);

        Self {
            oauth_token: None,
            api_token: None,
            models: None,
            client,
        }
    }

    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

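    /// Streams a chat completion for the given request, refreshing the Copilot API
    /// token first if it expires within the next five minutes.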
    pub async fn stream_completion(
        request: Request,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
            return Err(anyhow!("Copilot chat is not enabled"));
        };

        let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
            )
        })?;

        let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?;

        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
            _ => {
                let token = request_api_token(&oauth_token, client.clone()).await?;
                this.update(&mut cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        stream_completion(client.clone(), token.api_key, request).await
    }
}

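/// Fetches the model list from the Copilot API, keeps only models the user can access,
/// deduplicates by model family, and moves the default model to the front of the list.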
async fn get_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
    let all_models = request_models(api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            // Ensure the user has access to the model; a policy is present only for models
            // that must be enabled in the GitHub dashboard.
            model.model_picker_enabled
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        // The first model of any given family in the API response appears to be the
        // non-tagged model, which is likely the best choice (e.g. gpt-4o rather than
        // gpt-4o-2024-11-20).
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    if let Some(default_model_position) =
        models.iter().position(|model| model.id == DEFAULT_MODEL_ID)
    {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}

async fn request_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(COPILOT_CHAT_MODELS_URL)
        .header("Authorization", format!("Bearer {}", api_token))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        let models = serde_json::from_str::<ModelSchema>(body_str)?.data;

        Ok(models)
    } else {
        Err(anyhow!("Failed to request models: {}", response.status()))
    }
}

async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(COPILOT_CHAT_AUTH_URL)
        .header("Authorization", format!("token {}", oauth_token))
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
        ApiToken::try_from(parsed)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        Err(anyhow!("Failed to request API token: {}", body_str))
    }
}

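/// Pulls the Copilot OAuth token out of a `hosts.json`/`apps.json`-style document:
/// the first `oauth_token` found under a key starting with "github.com".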
fn extract_oauth_token(contents: String) -> Option<String> {
    serde_json::from_str::<serde_json::Value>(&contents)
        .map(|v| {
            v.as_object().and_then(|obj| {
                obj.iter().find_map(|(key, value)| {
                    if key.starts_with("github.com") {
                        value["oauth_token"].as_str().map(|v| v.to_string())
                    } else {
                        None
                    }
                })
            })
        })
        .ok()
        .flatten()
}

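/// Sends a chat completion request. When `request.stream` is true, the server-sent
/// `data:` lines are parsed into a stream of [`ResponseEvent`]s; otherwise the single
/// response body is wrapped in a one-item stream.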
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(COPILOT_CHAT_COMPLETION_URL)
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat");

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        return Err(anyhow!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        ));
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
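
    #[test]
    fn test_extract_oauth_token() {
        // Minimal sketch of a hosts.json-style payload; the exact fields GitHub writes
        // may differ, but only an `oauth_token` under a key starting with "github.com"
        // matters to `extract_oauth_token`.
        let contents = r#"{
            "github.com": {
                "oauth_token": "gho_example_token",
                "user": "octocat"
            }
        }"#
        .to_string();

        assert_eq!(
            extract_oauth_token(contents),
            Some("gho_example_token".to_string())
        );
        assert_eq!(extract_oauth_token("{}".to_string()), None);
    }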
}