1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::Context as _;
6use anyhow::{Result, anyhow};
7use chrono::DateTime;
8use collections::HashSet;
9use fs::Fs;
10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
11use gpui::WeakEntity;
12use gpui::{App, AsyncApp, Global, prelude::*};
13use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
14use itertools::Itertools;
15use paths::home_dir;
16use serde::{Deserialize, Serialize};
17use settings::watch_config_dir;
18
/// Endpoint configuration for GitHub Copilot Chat.
///
/// URLs are stored as `Arc<str>` so they can be cloned cheaply each time they
/// are read out of the global entity for a request.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatSettings {
    /// Base URL for chat completion requests (see `stream_completion`).
    pub api_url: Arc<str>,
    /// URL used to exchange the OAuth token for a short-lived API token
    /// (see `request_api_token`).
    pub auth_url: Arc<str>,
    /// URL used to list the models available to the current user
    /// (see `get_models`).
    pub models_url: Arc<str>,
}
25
// Copilot's base model; defined by Microsoft in premium requests table
// This will be moved to the front of the Copilot model list, and will be used for
// 'fast' requests (e.g. title generation)
// https://docs.github.com/en/copilot/managing-copilot/monitoring-usage-and-entitlements/about-premium-requests
const DEFAULT_MODEL_ID: &str = "gpt-4.1";
31
/// Chat participant role, serialized in lowercase (`"user"`, `"assistant"`,
/// `"system"`) to match the wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
39
/// Wire format of the models endpoint response; only the `data` array is used.
/// Entries that fail to deserialize are skipped instead of failing the whole
/// response (see `deserialize_models_skip_errors`).
#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
45
46fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
47where
48 D: serde::Deserializer<'de>,
49{
50 let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
51 let models = raw_values
52 .into_iter()
53 .filter_map(|value| match serde_json::from_value::<Model>(value) {
54 Ok(model) => Some(model),
55 Err(err) => {
56 log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
57 None
58 }
59 })
60 .collect();
61
62 Ok(models)
63}
64
/// A model entry as returned by the Copilot models endpoint.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct Model {
    capabilities: ModelCapabilities,
    // Stable identifier, e.g. "gpt-4.1" (compare `DEFAULT_MODEL_ID`).
    id: String,
    // Human-readable display name.
    name: String,
    // Present only for models that must be enabled in the GitHub dashboard;
    // used by `get_models` to filter out models the user cannot access.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    // `get_models` drops entries where this is false.
    model_picker_enabled: bool,
}
74
/// Capability metadata nested inside a `Model`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    // Model family (e.g. "gpt-4"); `get_models` dedupes by this field.
    family: String,
    // Defaults to all-zero limits when the field is absent.
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
}
82
/// Token limits for a model; each field falls back to 0 when missing from the
/// response.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    // Exposed via `Model::max_token_count`.
    #[serde(default)]
    max_prompt_tokens: usize,
}
92
/// Access policy attached to gated models; `get_models` keeps only models
/// whose state is "enabled".
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}
97
/// Feature flags for a model; every flag defaults to false when absent.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
109
/// Vendor of a Copilot model, as reported by the models endpoint.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
}
118
/// One part of a multipart chat message, tagged with a `"type"` field on the
/// wire (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
127
/// Image payload for `ChatMessagePart::Image`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
132
133impl Model {
134 pub fn uses_streaming(&self) -> bool {
135 self.capabilities.supports.streaming
136 }
137
138 pub fn id(&self) -> &str {
139 self.id.as_str()
140 }
141
142 pub fn display_name(&self) -> &str {
143 self.name.as_str()
144 }
145
146 pub fn max_token_count(&self) -> usize {
147 self.capabilities.limits.max_prompt_tokens
148 }
149
150 pub fn supports_tools(&self) -> bool {
151 self.capabilities.supports.tool_calls
152 }
153
154 pub fn vendor(&self) -> ModelVendor {
155 self.vendor
156 }
157
158 pub fn supports_vision(&self) -> bool {
159 self.capabilities.supports.vision
160 }
161
162 pub fn supports_parallel_tool_calls(&self) -> bool {
163 self.capabilities.supports.parallel_tool_calls
164 }
165}
166
/// Body of a chat completion request.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    // Number of completions to generate.
    pub n: usize,
    // When true the response is delivered as server-sent events.
    pub stream: bool,
    pub temperature: f32,
    // Model id, as returned by `Model::id`.
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // Omitted from the serialized payload when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    // Omitted from the serialized payload when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
180
/// Declaration of a callable function offered to the model; `parameters` is a
/// JSON schema object.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
187
/// A tool the model may invoke; currently only function tools, tagged with
/// `"type": "function"` on the wire.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}
193
/// Tool-selection policy, serialized lowercase (`"auto"`, `"any"`, `"none"`).
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
201
/// A chat message, tagged with a lowercase `"role"` field on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        // Omitted from the serialized payload when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: ChatMessageContent,
    },
    // System messages carry plain text only, never multipart content.
    System {
        content: String,
    },
    // Result of a tool invocation, correlated to the request via
    // `tool_call_id`.
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}
221
/// Message content: either a bare string or an array of typed parts. Untagged,
/// so the JSON shape alone selects the variant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}
228
229impl ChatMessageContent {
230 pub fn empty() -> Self {
231 ChatMessageContent::Multipart(vec![])
232 }
233}
234
235impl From<Vec<ChatMessagePart>> for ChatMessageContent {
236 fn from(mut parts: Vec<ChatMessagePart>) -> Self {
237 if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
238 ChatMessageContent::Plain(std::mem::take(text))
239 } else {
240 ChatMessageContent::Multipart(parts)
241 }
242 }
243}
244
245impl From<String> for ChatMessageContent {
246 fn from(text: String) -> Self {
247 ChatMessageContent::Plain(text)
248 }
249}
250
/// A tool call emitted by the assistant; `content` is flattened so its fields
/// appear alongside `id` in the JSON object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
257
/// Payload of a tool call; currently only function calls, tagged with
/// `"type": "function"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
263
/// A concrete function invocation; `arguments` is a JSON-encoded string, not
/// a parsed object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
269
/// A completion response payload: one parsed SSE event when streaming, or the
/// entire response body when not streaming.
// NOTE(review): `tag = "type"` / `rename_all = "snake_case"` look copy-pasted
// from an enum — the fields are already snake_case and only `Deserialize` is
// derived, so these attributes appear to be inert here. Confirm before
// removing them.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    // Streaming events whose `choices` list is empty are filtered out by
    // `stream_completion`.
    pub choices: Vec<ResponseChoice>,
    pub id: String,
}
276
/// One choice within a `ResponseEvent`.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    // Which of `delta`/`message` is populated depends on whether the request
    // was streamed — presumably `delta` for SSE chunks and `message` for
    // non-streaming responses, mirroring the OpenAI chat schema; verify
    // against the API.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}
284
/// (Partial) message content inside a response choice; all fields optional
/// because streaming chunks carry only what changed.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}
292
/// Fragment of a tool call delivered across streaming chunks; fragments
/// belonging to the same call share an `index`.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
299
/// Fragment of a function invocation; `arguments` arrives incrementally as a
/// partial JSON string.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
305
/// Wire format of the token-exchange endpoint response.
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp (seconds) at which `token` expires.
    expires_at: i64,
}
311
/// A short-lived Copilot API token together with its expiry time.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
}
317
318impl ApiToken {
319 pub fn remaining_seconds(&self) -> i64 {
320 self.expires_at
321 .timestamp()
322 .saturating_sub(chrono::Utc::now().timestamp())
323 }
324}
325
326impl TryFrom<ApiTokenResponse> for ApiToken {
327 type Error = anyhow::Error;
328
329 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
330 let expires_at =
331 DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
332
333 Ok(Self {
334 api_key: response.token,
335 expires_at,
336 })
337 }
338}
339
/// Newtype that registers the `CopilotChat` entity as a gpui global.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
343
/// Global state for GitHub Copilot Chat: credentials, endpoint settings, the
/// cached model list, and the HTTP client used for API calls.
pub struct CopilotChat {
    // Long-lived OAuth token read from the Copilot config files on disk.
    oauth_token: Option<String>,
    // Short-lived API token exchanged from the OAuth token; refreshed when
    // close to expiry (see `stream_completion`).
    api_token: Option<ApiToken>,
    settings: CopilotChatSettings,
    // Models fetched from the models endpoint; `None` until the first fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
351
352pub fn init(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut App) {
353 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, cx));
354 cx.set_global(GlobalCopilotChat(copilot_chat));
355}
356
357pub fn copilot_chat_config_dir() -> &'static PathBuf {
358 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
359
360 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
361 if cfg!(target_os = "windows") {
362 home_dir().join("AppData").join("Local")
363 } else {
364 home_dir().join(".config")
365 }
366 .join("github-copilot")
367 })
368}
369
370fn copilot_chat_config_paths() -> [PathBuf; 2] {
371 let base_dir = copilot_chat_config_dir();
372 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
373}
374
impl CopilotChat {
    /// Returns the process-wide `CopilotChat` entity, if `init` has been run.
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    /// Creates the chat state and spawns a background task that watches the
    /// Copilot config directory, extracting the OAuth token whenever the
    /// credential files change.
    fn new(fs: Arc<dyn Fs>, client: Arc<dyn HttpClient>, cx: &mut Context<Self>) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();
        let settings = CopilotChatSettings::default();
        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            // Each event delivers the current contents of a watched file.
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_token = extract_oauth_token(contents);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                // Once a token is available, (re)fetch the model list.
                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        Self {
            oauth_token: None,
            api_token: None,
            models: None,
            settings,
            client,
        }
    }

    /// Exchanges the stored OAuth token for a fresh API token, caches it, and
    /// then refreshes the model list from the models endpoint.
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, auth_url) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.settings.auth_url.clone(),
            )
        })?;
        let api_token = request_api_token(
            &oauth_token.ok_or_else(|| {
                anyhow!("OAuth token is missing while updating Copilot Chat models")
            })?,
            auth_url,
            client.clone(),
        )
        .await?;

        // Store the token before fetching models so other callers can reuse
        // it even if the model fetch fails.
        let models_url = this.update(cx, |this, cx| {
            this.api_token = Some(api_token.clone());
            cx.notify();
            this.settings.models_url.clone()
        })?;
        let models = get_models(models_url, api_token.api_key, client.clone()).await?;

        this.update(cx, |this, cx| {
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    /// True once an OAuth token has been found on disk.
    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    /// The cached model list, or `None` if it has not been fetched yet.
    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    /// Sends a completion request via the global instance, transparently
    /// refreshing the API token when it is absent or close to expiry.
    pub async fn stream_completion(
        request: Request,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let this = cx
            .update(|cx| Self::global(cx))
            .ok()
            .flatten()
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client, api_url, auth_url) =
            this.read_with(&cx, |this, _| {
                (
                    this.oauth_token.clone(),
                    this.api_token.clone(),
                    this.client.clone(),
                    this.settings.api_url.clone(),
                    this.settings.auth_url.clone(),
                )
            })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

        // Reuse the cached API token unless it expires within five minutes;
        // otherwise exchange the OAuth token for a new one and cache it.
        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
            _ => {
                let token = request_api_token(&oauth_token, auth_url, client.clone()).await?;
                this.update(&mut cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        stream_completion(client.clone(), token.api_key, api_url, request).await
    }

    /// Replaces the endpoint settings; when they actually changed, spawns a
    /// background refresh of the model list.
    // NOTE(review): errors from the refresh are discarded by `detach()`; the
    // refresh also fails early (with the missing-OAuth-token error) when no
    // token has been discovered yet.
    pub fn set_settings(&mut self, settings: CopilotChatSettings, cx: &mut Context<Self>) {
        let same_settings = self.settings == settings;
        self.settings = settings;
        if !same_settings {
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}
506
/// Fetches the model list and post-processes it: keeps only models the user
/// can access, collapses tagged variants within a family down to the first
/// entry, and moves the default model to the front.
async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            // Ensure user has access to the model; Policy is present only for models that must be
            // enabled in the GitHub dashboard
            model.model_picker_enabled
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        // The first model from the API response, in any given family, appear to be the non-tagged
        // models, which are likely the best choice (e.g. gpt-4o rather than gpt-4o-2024-11-20)
        // NOTE: `dedup_by` only collapses *consecutive* duplicates, so this
        // relies on the API grouping each family together in its response.
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    // Promote the default model (used for 'fast' requests) to the front.
    if let Some(default_model_position) =
        models.iter().position(|model| model.id == DEFAULT_MODEL_ID)
    {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}
539
540async fn request_models(
541 models_url: Arc<str>,
542 api_token: String,
543 client: Arc<dyn HttpClient>,
544) -> Result<Vec<Model>> {
545 let request_builder = HttpRequest::builder()
546 .method(Method::GET)
547 .uri(models_url.as_ref())
548 .header("Authorization", format!("Bearer {}", api_token))
549 .header("Content-Type", "application/json")
550 .header("Copilot-Integration-Id", "vscode-chat");
551
552 let request = request_builder.body(AsyncBody::empty())?;
553
554 let mut response = client.send(request).await?;
555
556 anyhow::ensure!(
557 response.status().is_success(),
558 "Failed to request models: {}",
559 response.status()
560 );
561 let mut body = Vec::new();
562 response.body_mut().read_to_end(&mut body).await?;
563
564 let body_str = std::str::from_utf8(&body)?;
565
566 let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
567
568 Ok(models)
569}
570
571async fn request_api_token(
572 oauth_token: &str,
573 auth_url: Arc<str>,
574 client: Arc<dyn HttpClient>,
575) -> Result<ApiToken> {
576 let request_builder = HttpRequest::builder()
577 .method(Method::GET)
578 .uri(auth_url.as_ref())
579 .header("Authorization", format!("token {}", oauth_token))
580 .header("Accept", "application/json");
581
582 let request = request_builder.body(AsyncBody::empty())?;
583
584 let mut response = client.send(request).await?;
585
586 if response.status().is_success() {
587 let mut body = Vec::new();
588 response.body_mut().read_to_end(&mut body).await?;
589
590 let body_str = std::str::from_utf8(&body)?;
591
592 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
593 ApiToken::try_from(parsed)
594 } else {
595 let mut body = Vec::new();
596 response.body_mut().read_to_end(&mut body).await?;
597
598 let body_str = std::str::from_utf8(&body)?;
599 anyhow::bail!("Failed to request API token: {body_str}");
600 }
601}
602
603fn extract_oauth_token(contents: String) -> Option<String> {
604 serde_json::from_str::<serde_json::Value>(&contents)
605 .map(|v| {
606 v.as_object().and_then(|obj| {
607 obj.iter().find_map(|(key, value)| {
608 if key.starts_with("github.com") {
609 value["oauth_token"].as_str().map(|v| v.to_string())
610 } else {
611 None
612 }
613 })
614 })
615 })
616 .ok()
617 .flatten()
618}
619
/// Sends the completion request and adapts the response into a stream of
/// `ResponseEvent`s: SSE lines are parsed when `request.stream` is set,
/// otherwise the single JSON body is wrapped in a one-item stream.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // Detect whether the last message carries an image part so the
    // `Copilot-Vision-Request` header below can be set accordingly.
    let is_vision_request = request.messages.last().map_or(false, |message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("Copilot-Vision-Request", is_vision_request.to_string());

    // Remember the mode before `request` is consumed by serialization below.
    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data lines start with "data: "; any other line
                        // is dropped.
                        let line = line.strip_prefix("data: ")?;
                        // "[DONE]" is the end-of-stream sentinel.
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Events with no choices carry nothing useful
                                // and are skipped.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        // Present the single response through the same stream interface.
        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
703
#[cfg(test)]
mod tests {
    use super::*;

    /// A malformed entry in `data` (here: an object with only an unknown
    /// field) is skipped, while the well-formed entries around it still
    /// deserialize.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(&json).unwrap();

        // The unknown-shaped middle entry is dropped; the two valid models
        // survive in order.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
}