copilot_chat.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3use std::sync::OnceLock;
  4
  5use anyhow::Context as _;
  6use anyhow::{Result, anyhow};
  7use chrono::DateTime;
  8use collections::HashSet;
  9use fs::Fs;
 10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 11use gpui::{App, AsyncApp, Global, prelude::*};
 12use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 13use itertools::Itertools;
 14use paths::home_dir;
 15use serde::{Deserialize, Serialize};
 16use settings::watch_config_dir;
 17
 18pub const COPILOT_CHAT_COMPLETION_URL: &str = "https://api.githubcopilot.com/chat/completions";
 19pub const COPILOT_CHAT_AUTH_URL: &str = "https://api.github.com/copilot_internal/v2/token";
 20pub const COPILOT_CHAT_MODELS_URL: &str = "https://api.githubcopilot.com/models";
 21
 22// Copilot's base model; defined by Microsoft in premium requests table
 23// This will be moved to the front of the Copilot model list, and will be used for
 24// 'fast' requests (e.g. title generation)
 25// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests
 26const DEFAULT_MODEL_ID: &str = "gpt-4.1";
 27
/// Chat participant role; serialized lowercase ("user"/"assistant"/"system")
/// as expected by the Copilot chat API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

/// Top-level shape of the /models response; only the `data` array is consumed.
#[derive(Deserialize)]
struct ModelSchema {
    // Custom deserializer skips entries that fail to parse instead of failing
    // the entire response (see `deserialize_models_skip_errors`).
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
 41
 42fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
 43where
 44    D: serde::Deserializer<'de>,
 45{
 46    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
 47    let models = raw_values
 48        .into_iter()
 49        .filter_map(|value| match serde_json::from_value::<Model>(value) {
 50            Ok(model) => Some(model),
 51            Err(err) => {
 52                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
 53                None
 54            }
 55        })
 56        .collect();
 57
 58    Ok(models)
 59}
 60
/// A chat model as reported by the Copilot /models endpoint.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct Model {
    capabilities: ModelCapabilities,
    // Stable identifier sent back in completion requests (e.g. "gpt-4.1").
    id: String,
    // Human-readable name for display.
    name: String,
    // Present only for models that must be enabled in the GitHub dashboard.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    // Whether the model should be offered in the model picker.
    model_picker_enabled: bool,
}

/// Capability metadata for a model: family, token limits, feature support.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    // Model family (e.g. "gpt-4"); used to dedupe near-identical variants.
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
}

/// Token limits; every field falls back to 0 when omitted from the response.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: usize,
}

/// Dashboard policy; `state == "enabled"` means the user may use the model.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

/// Per-model feature flags; each defaults to false when the API omits it.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
105
/// Upstream provider of a model.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
}

/// One part of a multipart chat message: plain text or an image reference.
/// Tagged with a `type` field ("text" / "image_url") on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Image reference used by `ChatMessagePart::Image`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
128
129impl Model {
130    pub fn uses_streaming(&self) -> bool {
131        self.capabilities.supports.streaming
132    }
133
134    pub fn id(&self) -> &str {
135        self.id.as_str()
136    }
137
138    pub fn display_name(&self) -> &str {
139        self.name.as_str()
140    }
141
142    pub fn max_token_count(&self) -> usize {
143        self.capabilities.limits.max_prompt_tokens
144    }
145
146    pub fn supports_tools(&self) -> bool {
147        self.capabilities.supports.tool_calls
148    }
149
150    pub fn vendor(&self) -> ModelVendor {
151        self.vendor
152    }
153
154    pub fn supports_vision(&self) -> bool {
155        self.capabilities.supports.vision
156    }
157
158    pub fn supports_parallel_tool_calls(&self) -> bool {
159        self.capabilities.supports.parallel_tool_calls
160    }
161}
162
/// Body of a chat completion request sent to the Copilot API.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    // Number of completion choices to request.
    pub n: usize,
    // When true, the response arrives as an SSE stream.
    pub stream: bool,
    pub temperature: f32,
    // Model id (see `Model::id`).
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// A callable function exposed to the model, with a JSON-schema parameter spec.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

/// Tool definition; only function tools exist today.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// How the model should pick tools; serialized lowercase ("auto"/"any"/"none").
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
197
/// A single chat message, tagged by a `role` field on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    // Result of a tool invocation, correlated back via `tool_call_id`.
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

/// Message content: either a bare string or a list of typed parts.
/// Untagged, so the serialized form is whichever variant is held.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    /// Content with no parts (serializes as an empty array).
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}
230
231impl From<Vec<ChatMessagePart>> for ChatMessageContent {
232    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
233        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
234            ChatMessageContent::Plain(std::mem::take(text))
235        } else {
236            ChatMessageContent::Multipart(parts)
237        }
238    }
239}
240
241impl From<String> for ChatMessageContent {
242    fn from(text: String) -> Self {
243        ChatMessageContent::Plain(text)
244    }
245}
246
/// A tool call requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` on the wire.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call; only function calls are supported.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Function name plus its arguments as a raw JSON string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
265
/// One completion event: a streamed chunk, or the whole non-streamed response.
// NOTE(review): `tag = "type"` on a struct affects how it serializes; this type
// derives only Deserialize, so the attribute looks inert — confirm before touching.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
}

/// One of the `n` completion choices within a response event.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    // Presumably `delta` is filled for streamed chunks and `message` for full
    // responses — both are optional; verify against API behavior.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Incremental (or complete) message payload for a choice.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}

/// Partial tool call delivered across streamed chunks, correlated by `index`.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function name/arguments within a `ToolCallChunk`.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
301
/// Wire format of the Copilot auth endpoint's token-exchange response.
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp (seconds).
    expires_at: i64,
}

/// A short-lived API token together with its parsed expiry.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
}
313
314impl ApiToken {
315    pub fn remaining_seconds(&self) -> i64 {
316        self.expires_at
317            .timestamp()
318            .saturating_sub(chrono::Utc::now().timestamp())
319    }
320}
321
322impl TryFrom<ApiTokenResponse> for ApiToken {
323    type Error = anyhow::Error;
324
325    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
326        let expires_at =
327            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
328
329        Ok(Self {
330            api_key: response.token,
331            expires_at,
332        })
333    }
334}
335
/// Newtype wrapping the app-global `CopilotChat` entity.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

/// Global Copilot Chat state: credentials, the fetched model list, and the
/// HTTP client used for all API calls.
pub struct CopilotChat {
    // OAuth token read from the Copilot config files on disk.
    oauth_token: Option<String>,
    // Short-lived API token exchanged from the OAuth token.
    api_token: Option<ApiToken>,
    // Models fetched from the API; `None` until the first successful fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
346
/// Creates the `CopilotChat` entity and registers it as an app global.
pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}
351
352pub fn copilot_chat_config_dir() -> &'static PathBuf {
353    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
354
355    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
356        if cfg!(target_os = "windows") {
357            home_dir().join("AppData").join("Local")
358        } else {
359            home_dir().join(".config")
360        }
361        .join("github-copilot")
362    })
363}
364
365fn copilot_chat_config_paths() -> [PathBuf; 2] {
366    let base_dir = copilot_chat_config_dir();
367    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
368}
369
impl CopilotChat {
    /// Fetches the global `CopilotChat` entity, if `init` has been called.
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    /// Creates the (initially empty) state and spawns a background task that
    /// watches the Copilot config directory. On every change the task
    /// re-extracts the OAuth token, exchanges it for an API token, and
    /// refreshes the model list, notifying observers after each step.
    pub fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &App) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn({
            let client = client.clone();
            async move |cx| {
                let mut parent_watch_rx = watch_config_dir(
                    cx.background_executor(),
                    fs.clone(),
                    dir_path.clone(),
                    config_paths,
                );
                while let Some(contents) = parent_watch_rx.next().await {
                    // Store the (possibly absent) OAuth token first, so
                    // `is_authenticated` reflects the on-disk state immediately.
                    let oauth_token = extract_oauth_token(contents);
                    cx.update(|cx| {
                        if let Some(this) = Self::global(cx).as_ref() {
                            this.update(cx, |this, cx| {
                                this.oauth_token = oauth_token.clone();
                                cx.notify();
                            });
                        }
                    })?;

                    if let Some(ref oauth_token) = oauth_token {
                        // NOTE(review): a failed token or model request propagates
                        // via `?` and ends this watcher task, so later config
                        // changes would go unnoticed — confirm this is intended.
                        let api_token = request_api_token(oauth_token, client.clone()).await?;
                        cx.update(|cx| {
                            if let Some(this) = Self::global(cx).as_ref() {
                                this.update(cx, |this, cx| {
                                    this.api_token = Some(api_token.clone());
                                    cx.notify();
                                });
                            }
                        })?;
                        let models = get_models(api_token.api_key, client.clone()).await?;
                        cx.update(|cx| {
                            if let Some(this) = Self::global(cx).as_ref() {
                                this.update(cx, |this, cx| {
                                    this.models = Some(models);
                                    cx.notify();
                                });
                            }
                        })?;
                    }
                }
                anyhow::Ok(())
            }
        })
        .detach_and_log_err(cx);

        Self {
            oauth_token: None,
            api_token: None,
            models: None,
            client,
        }
    }

    /// True once an OAuth token has been found on disk (the API token may
    /// still be pending).
    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    /// The fetched model list, or `None` until the first successful fetch.
    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    /// Starts a chat completion using the globally stored credentials,
    /// refreshing the API token when it has under five minutes of validity left.
    pub async fn stream_completion(
        request: Request,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let this = cx
            .update(|cx| Self::global(cx))
            .ok()
            .flatten()
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
            )
        })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

        let token = match api_token {
            // Reuse the cached token while it retains a 5-minute safety margin.
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
            _ => {
                let token = request_api_token(&oauth_token, client.clone()).await?;
                this.update(&mut cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        stream_completion(client.clone(), token.api_key, request).await
    }
}
477
478async fn get_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
479    let all_models = request_models(api_token, client).await?;
480
481    let mut models: Vec<Model> = all_models
482        .into_iter()
483        .filter(|model| {
484            // Ensure user has access to the model; Policy is present only for models that must be
485            // enabled in the GitHub dashboard
486            model.model_picker_enabled
487                && model
488                    .policy
489                    .as_ref()
490                    .is_none_or(|policy| policy.state == "enabled")
491        })
492        // The first model from the API response, in any given family, appear to be the non-tagged
493        // models, which are likely the best choice (e.g. gpt-4o rather than gpt-4o-2024-11-20)
494        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
495        .collect();
496
497    if let Some(default_model_position) =
498        models.iter().position(|model| model.id == DEFAULT_MODEL_ID)
499    {
500        let default_model = models.remove(default_model_position);
501        models.insert(0, default_model);
502    }
503
504    Ok(models)
505}
506
507async fn request_models(api_token: String, client: Arc<dyn HttpClient>) -> Result<Vec<Model>> {
508    let request_builder = HttpRequest::builder()
509        .method(Method::GET)
510        .uri(COPILOT_CHAT_MODELS_URL)
511        .header("Authorization", format!("Bearer {}", api_token))
512        .header("Content-Type", "application/json")
513        .header("Copilot-Integration-Id", "vscode-chat");
514
515    let request = request_builder.body(AsyncBody::empty())?;
516
517    let mut response = client.send(request).await?;
518
519    anyhow::ensure!(
520        response.status().is_success(),
521        "Failed to request models: {}",
522        response.status()
523    );
524    let mut body = Vec::new();
525    response.body_mut().read_to_end(&mut body).await?;
526
527    let body_str = std::str::from_utf8(&body)?;
528
529    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
530
531    Ok(models)
532}
533
534async fn request_api_token(oauth_token: &str, client: Arc<dyn HttpClient>) -> Result<ApiToken> {
535    let request_builder = HttpRequest::builder()
536        .method(Method::GET)
537        .uri(COPILOT_CHAT_AUTH_URL)
538        .header("Authorization", format!("token {}", oauth_token))
539        .header("Accept", "application/json");
540
541    let request = request_builder.body(AsyncBody::empty())?;
542
543    let mut response = client.send(request).await?;
544
545    if response.status().is_success() {
546        let mut body = Vec::new();
547        response.body_mut().read_to_end(&mut body).await?;
548
549        let body_str = std::str::from_utf8(&body)?;
550
551        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
552        ApiToken::try_from(parsed)
553    } else {
554        let mut body = Vec::new();
555        response.body_mut().read_to_end(&mut body).await?;
556
557        let body_str = std::str::from_utf8(&body)?;
558        anyhow::bail!("Failed to request API token: {body_str}");
559    }
560}
561
562fn extract_oauth_token(contents: String) -> Option<String> {
563    serde_json::from_str::<serde_json::Value>(&contents)
564        .map(|v| {
565            v.as_object().and_then(|obj| {
566                obj.iter().find_map(|(key, value)| {
567                    if key.starts_with("github.com") {
568                        value["oauth_token"].as_str().map(|v| v.to_string())
569                    } else {
570                        None
571                    }
572                })
573            })
574        })
575        .ok()
576        .flatten()
577}
578
/// Sends a chat completion request to the Copilot API.
///
/// Returns a stream of `ResponseEvent`s: streaming requests parse SSE `data:`
/// lines as they arrive; non-streaming requests wrap the single JSON response
/// in a one-item stream.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // The vision header is keyed off the *last* message only: it is set when
    // that message's multipart content contains an image part.
    let is_vision_request = request.messages.last().map_or(false, |message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(COPILOT_CHAT_COMPLETION_URL)
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("Copilot-Vision-Request", is_vision_request.to_string());

    // Capture before `request` is consumed by serialization below.
    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE frames look like `data: {...}`; any other line is skipped.
                        let line = line.strip_prefix("data: ")?;
                        // `[DONE]` is the end-of-stream sentinel, not JSON.
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Events with no choices carry nothing actionable.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies `ModelSchema`'s resilient deserialization: the entry containing
    /// only `some-unknown-field` must be skipped, while both valid models
    /// (including one using the "Azure OpenAI" vendor alias) still parse.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
              "data": [
                {
                  "capabilities": {
                    "family": "gpt-4",
                    "limits": {
                      "max_context_window_tokens": 32768,
                      "max_output_tokens": 4096,
                      "max_prompt_tokens": 32768
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "tokenizer": "cl100k_base",
                    "type": "chat"
                  },
                  "id": "gpt-4",
                  "model_picker_enabled": false,
                  "name": "GPT 4",
                  "object": "model",
                  "preview": false,
                  "vendor": "Azure OpenAI",
                  "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                  "capabilities": {
                    "family": "claude-3.7-sonnet",
                    "limits": {
                      "max_context_window_tokens": 200000,
                      "max_output_tokens": 16384,
                      "max_prompt_tokens": 90000,
                      "vision": {
                        "max_prompt_image_size": 3145728,
                        "max_prompt_images": 1,
                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                      }
                    },
                    "object": "model_capabilities",
                    "supports": {
                      "parallel_tool_calls": true,
                      "streaming": true,
                      "tool_calls": true,
                      "vision": true
                    },
                    "tokenizer": "o200k_base",
                    "type": "chat"
                  },
                  "id": "claude-3.7-sonnet",
                  "model_picker_enabled": true,
                  "name": "Claude 3.7 Sonnet",
                  "object": "model",
                  "policy": {
                    "state": "enabled",
                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                  },
                  "preview": false,
                  "vendor": "Anthropic",
                  "version": "claude-3.7-sonnet"
                }
              ],
              "object": "list"
            }"#;

        // `json` is already a `&str`; passing it directly avoids the needless
        // extra borrow (`&json`) flagged by clippy.
        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
}