pub mod responses;

use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;

use anyhow::Context as _;
use anyhow::{Result, anyhow};
use chrono::DateTime;
use collections::HashSet;
use fs::Fs;
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use gpui::WeakEntity;
use gpui::{App, AsyncApp, Global, prelude::*};
use http_client::HttpRequestExt;
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use paths::home_dir;
use serde::{Deserialize, Serialize};

use settings::watch_config_dir;

pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";

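/// Configuration for GitHub Copilot Chat. When `enterprise_uri` is set, OAuth
/// and token lookups are directed at that GitHub Enterprise domain instead of
/// github.com.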
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    pub fn token_url(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            let domain = Self::parse_domain(enterprise_uri);
            format!("https://api.{}/copilot_internal/v2/token", domain)
        } else {
            "https://api.github.com/copilot_internal/v2/token".to_string()
        }
    }

    pub fn oauth_domain(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            Self::parse_domain(enterprise_uri)
        } else {
            "github.com".to_string()
        }
    }

    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/chat/completions", endpoint)
    }

    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/responses", endpoint)
    }

    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/models", endpoint)
    }

    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');

        if let Some(domain) = uri.strip_prefix("https://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else if let Some(domain) = uri.strip_prefix("http://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else {
            uri.split('/').next().unwrap_or(uri).to_string()
        }
    }
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
    #[serde(rename = "/v1/messages")]
    Messages,
    /// Unknown endpoint that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}

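/// Deserializes the model list leniently: entries that fail to parse are logged
/// and skipped rather than failing the entire response.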
fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
    let models = raw_values
        .into_iter()
        .filter_map(|value| match serde_json::from_value::<Model>(value) {
            Ok(model) => Some(model),
            Err(err) => {
                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
                None
            }
        })
        .collect();

    Ok(models)
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // VS Code Copilot selects the model with this flag set when the premium
    // request limit is reached. Zed does not currently implement this behaviour.
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}

impl Model {
    pub fn uses_streaming(&self) -> bool {
        self.capabilities.supports.streaming
    }

    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    pub fn display_name(&self) -> &str {
        self.name.as_str()
    }

    pub fn max_token_count(&self) -> u64 {
        self.capabilities.limits.max_prompt_tokens
    }

    pub fn supports_tools(&self) -> bool {
        self.capabilities.supports.tool_calls
    }

    pub fn vendor(&self) -> ModelVendor {
        self.vendor
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.supports.vision
    }

    pub fn supports_parallel_tool_calls(&self) -> bool {
        self.capabilities.supports.parallel_tool_calls
    }

    pub fn tokenizer(&self) -> Option<&str> {
        self.capabilities.tokenizer.as_deref()
    }

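    /// Whether requests for this model should use the `/responses` endpoint:
    /// true only when the model advertises `/responses` and does not advertise
    /// `/chat/completions`.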
    pub fn supports_response(&self) -> bool {
        !self.supported_endpoints.is_empty()
            && !self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::ChatCompletions)
            && self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::Responses)
    }
}

#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}

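// A lone text part collapses to `Plain`; anything else stays `Multipart`.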
impl From<Vec<ChatMessagePart>> for ChatMessageContent {
    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
            ChatMessageContent::Plain(std::mem::take(text))
        } else {
            ChatMessageContent::Multipart(parts)
        }
    }
}

impl From<String> for ChatMessageContent {
    fn from(text: String) -> Self {
        ChatMessageContent::Plain(text)
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}

#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}

#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}

impl ApiToken {
    pub fn remaining_seconds(&self) -> i64 {
        self.expires_at
            .timestamp()
            .saturating_sub(chrono::Utc::now().timestamp())
    }
}

impl TryFrom<ApiTokenResponse> for ApiToken {
    type Error = anyhow::Error;

    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
        let expires_at =
            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;

        Ok(Self {
            api_key: response.token,
            expires_at,
            api_endpoint: response.endpoints.api,
        })
    }
}

struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

pub struct CopilotChat {
    oauth_token: Option<String>,
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}

pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}

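/// Returns the GitHub Copilot configuration directory: the local app data
/// directory on Windows, otherwise `$XDG_CONFIG_HOME` (falling back to
/// `~/.config`), joined with `github-copilot`.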
pub fn copilot_chat_config_dir() -> &'static PathBuf {
    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();

    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
        let config_dir = if cfg!(target_os = "windows") {
            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
        } else {
            std::env::var("XDG_CONFIG_HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| home_dir().join(".config"))
        };

        config_dir.join("github-copilot")
    })
}

fn copilot_chat_config_paths() -> [PathBuf; 2] {
    let base_dir = copilot_chat_config_dir();
    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
}

impl CopilotChat {
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_token: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let token_url = configuration.token_url();
        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;

        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
        let models =
            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;

        this.update(cx, |this, cx| {
            this.api_token = Some(api_token);
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
        stream_completion(
            client.clone(),
            token.api_key,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
        responses::stream_response(
            client.clone(),
            token.api_key,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
        let this = cx
            .update(|cx| Self::global(cx))?
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

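        // Reuse the cached API token unless it expires within the next five
        // minutes; otherwise request a fresh one and cache it.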
        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
            _ => {
                let token_url = configuration.token_url();
                let token =
                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
                this.update(cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        Ok((client, token, configuration))
    }

    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_token = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}

async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        .collect();

    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}

async fn request_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(models_url.as_ref())
        .header("Authorization", format!("Bearer {}", api_token))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("Editor-Version", "vscode/1.103.2")
        .header("x-github-api-version", "2025-05-01");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to request models: {}",
        response.status()
    );
    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    let body_str = std::str::from_utf8(&body)?;

    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;

    Ok(models)
}

async fn request_api_token(
    oauth_token: &str,
    auth_url: Arc<str>,
    client: Arc<dyn HttpClient>,
) -> Result<ApiToken> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(auth_url.as_ref())
        .header("Authorization", format!("token {}", oauth_token))
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
        ApiToken::try_from(parsed)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!("Failed to request API token: {body_str}");
    }
}

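// Scans the parsed hosts.json/apps.json object for a key that starts with the
// OAuth domain (keys may look like "github.com" or "github.com:user") and
// returns its oauth_token, if any.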
fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
    serde_json::from_str::<serde_json::Value>(&contents)
        .map(|v| {
            v.as_object().and_then(|obj| {
                obj.iter().find_map(|(key, value)| {
                    if key.starts_with(domain) {
                        value["oauth_token"].as_str().map(|v| v.to_string())
                    } else {
                        None
                    }
                })
            })
        })
        .ok()
        .flatten()
}

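// Issues a chat completion request. When `request.stream` is true the response
// is parsed as an SSE stream; otherwise the single JSON body is wrapped in a
// one-item stream.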
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator)
        .when(is_vision_request, |builder| {
            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
        });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 0
                    },
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "billing": {
                        "is_premium": true,
                        "multiplier": 1,
                        "restricted_to": [
                            "pro",
                            "pro_plus",
                            "business",
                            "enterprise"
                        ]
                    },
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 1
                    },
                    "capabilities": {
                        "family": "future-model",
                        "limits": {
                            "max_context_window_tokens": 128000,
                            "max_output_tokens": 8192,
                            "max_prompt_tokens": 120000
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "future-model-v1",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Future Model v1",
                    "object": "model",
                    "preview": false,
                    "vendor": "SomeNewVendor",
                    "version": "v1.0"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }

    #[test]
    fn test_models_with_pending_policy_deserialize() {
        // This test verifies that models with policy states other than "enabled"
        // (such as "pending" or "requires_consent") are properly deserialized.
        // Note: These models will be filtered out by get_models() and won't appear
        // in the model picker until the user enables them on GitHub.
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "policy": {
                        "state": "pending",
                        "terms": "Enable access to Claude models from Anthropic."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-opus-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-opus-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Opus 4",
                    "object": "model",
                    "policy": {
                        "state": "requires_consent",
                        "terms": "Enable access to Claude models from Anthropic."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-opus-4"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // Both models should deserialize successfully (filtering happens in get_models)
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "claude-sonnet-4");
        assert_eq!(schema.data[1].id, "claude-opus-4");
    }
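
    #[test]
    fn test_api_token_try_from_response() {
        // Illustrative test added in this pass (not part of the original suite):
        // the payload shape mirrors ApiTokenResponse, and the values are made up.
        let response: ApiTokenResponse = serde_json::from_str(
            r#"{
                "token": "tid=example-token",
                "expires_at": 1700000000,
                "endpoints": { "api": "https://api.example-copilot-endpoint.test" }
            }"#,
        )
        .unwrap();

        let token = ApiToken::try_from(response).unwrap();
        assert_eq!(token.api_key, "tid=example-token");
        assert_eq!(token.api_endpoint, "https://api.example-copilot-endpoint.test");
        assert_eq!(token.expires_at.timestamp(), 1_700_000_000);
    }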

    #[test]
    fn test_multiple_anthropic_models_preserved() {
        // This test verifies that multiple Claude models from Anthropic
        // are all preserved and not incorrectly deduplicated.
        // This was the root cause of issue #47540.
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-opus-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-opus-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Opus 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-opus-4"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4.5",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4.5",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4.5",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4.5"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // All three Anthropic models should be preserved
        assert_eq!(schema.data.len(), 3);
        assert_eq!(schema.data[0].id, "claude-sonnet-4");
        assert_eq!(schema.data[1].id, "claude-opus-4");
        assert_eq!(schema.data[2].id, "claude-sonnet-4.5");
    }
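
    #[test]
    fn test_chat_message_content_from_parts() {
        // Illustrative test added in this pass (not part of the original suite):
        // a single text part collapses to Plain, while mixed parts stay Multipart.
        let plain: ChatMessageContent = vec![ChatMessagePart::Text {
            text: "hello".to_string(),
        }]
        .into();
        assert!(matches!(plain, ChatMessageContent::Plain(text) if text == "hello"));

        let multipart: ChatMessageContent = vec![
            ChatMessagePart::Text {
                text: "look at this".to_string(),
            },
            ChatMessagePart::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.png".to_string(),
                },
            },
        ]
        .into();
        assert!(matches!(multipart, ChatMessageContent::Multipart(parts) if parts.len() == 2));
    }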

    #[test]
    fn test_models_with_same_family_both_preserved() {
        // Test that models sharing the same family (e.g., thinking variants)
        // are both preserved in the model list.
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4-thinking",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4 (Thinking)",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4-thinking"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // Both models should be preserved even though they share the same family
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "claude-sonnet-4");
        assert_eq!(schema.data[1].id, "claude-sonnet-4-thinking");
    }

    #[test]
    fn test_mixed_vendor_models_all_preserved() {
        // Test that models from different vendors are all preserved.
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": false, "multiplier": 1 },
                    "capabilities": {
                        "family": "gpt-4o",
                        "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "gpt-4o",
                    "is_chat_default": true,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "GPT-4o",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4o"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "gemini-2.0-flash",
                        "limits": { "max_context_window_tokens": 1000000, "max_output_tokens": 8192, "max_prompt_tokens": 900000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "gemini-2.0-flash",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Gemini 2.0 Flash",
                    "object": "model",
                    "preview": false,
                    "vendor": "Google",
                    "version": "gemini-2.0-flash"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // All three models from different vendors should be preserved
        assert_eq!(schema.data.len(), 3);
        assert_eq!(schema.data[0].id, "gpt-4o");
        assert_eq!(schema.data[1].id, "claude-sonnet-4");
        assert_eq!(schema.data[2].id, "gemini-2.0-flash");
    }

    #[test]
    fn test_model_with_messages_endpoint_deserializes() {
        // Anthropic Claude models use the /v1/messages endpoint.
        // This test verifies such models deserialize correctly (issue #47540 root cause).
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4",
                    "supported_endpoints": ["/v1/messages"]
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "claude-sonnet-4");
        assert_eq!(
            schema.data[0].supported_endpoints,
            vec![ModelSupportedEndpoint::Messages]
        );
    }

    #[test]
    fn test_model_with_unknown_endpoint_deserializes() {
        // Future-proofing: unknown endpoints should deserialize to Unknown variant
        // instead of causing the entire model to fail deserialization.
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": false, "multiplier": 1 },
                    "capabilities": {
                        "family": "future-model",
                        "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 8192, "max_prompt_tokens": 120000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "future-model-v2",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Future Model v2",
                    "object": "model",
                    "preview": false,
                    "vendor": "OpenAI",
                    "version": "v2.0",
                    "supported_endpoints": ["/v2/completions", "/chat/completions"]
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v2");
        assert_eq!(
            schema.data[0].supported_endpoints,
            vec![
                ModelSupportedEndpoint::Unknown,
                ModelSupportedEndpoint::ChatCompletions
            ]
        );
    }

    #[test]
    fn test_model_with_multiple_endpoints() {
        // Test model with multiple supported endpoints (common for newer models).
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "gpt-4o",
                        "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "gpt-4o",
                    "is_chat_default": true,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "GPT-4o",
                    "object": "model",
                    "preview": false,
                    "vendor": "OpenAI",
                    "version": "gpt-4o",
                    "supported_endpoints": ["/chat/completions", "/responses"]
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "gpt-4o");
        assert_eq!(
            schema.data[0].supported_endpoints,
            vec![
                ModelSupportedEndpoint::ChatCompletions,
                ModelSupportedEndpoint::Responses
            ]
        );
    }
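
    #[test]
    fn test_extract_oauth_token_matches_domain_prefix() {
        // Illustrative test added in this pass (not part of the original suite):
        // hosts.json keys may carry a suffix after the domain (e.g.
        // "github.com:user"), so extract_oauth_token matches on a key prefix.
        // The token values here are made up.
        let contents = r#"{
            "ghe.example.com": { "oauth_token": "gho_other" },
            "github.com:user": { "oauth_token": "gho_example" }
        }"#;

        assert_eq!(
            extract_oauth_token(contents.to_string(), "github.com"),
            Some("gho_example".to_string())
        );
        assert_eq!(extract_oauth_token(contents.to_string(), "missing.example"), None);
    }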

    #[test]
    fn test_supports_response_method() {
        // Test the supports_response() method which determines endpoint routing.
        let model_with_responses_only = Model {
            billing: ModelBilling {
                is_premium: false,
                multiplier: 1.0,
                restricted_to: None,
            },
            capabilities: ModelCapabilities {
                family: "test".to_string(),
                limits: ModelLimits::default(),
                supports: ModelSupportedFeatures {
                    streaming: true,
                    tool_calls: true,
                    parallel_tool_calls: false,
                    vision: false,
                },
                model_type: "chat".to_string(),
                tokenizer: None,
            },
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            policy: None,
            vendor: ModelVendor::OpenAI,
            is_chat_default: false,
            is_chat_fallback: false,
            model_picker_enabled: true,
            supported_endpoints: vec![ModelSupportedEndpoint::Responses],
        };

        let model_with_chat_completions = Model {
            supported_endpoints: vec![ModelSupportedEndpoint::ChatCompletions],
            ..model_with_responses_only.clone()
        };

        let model_with_both = Model {
            supported_endpoints: vec![
                ModelSupportedEndpoint::ChatCompletions,
                ModelSupportedEndpoint::Responses,
            ],
            ..model_with_responses_only.clone()
        };

        let model_with_messages = Model {
            supported_endpoints: vec![ModelSupportedEndpoint::Messages],
            ..model_with_responses_only.clone()
        };

        // Only /responses endpoint -> supports_response = true
        assert!(model_with_responses_only.supports_response());

        // Only /chat/completions endpoint -> supports_response = false
        assert!(!model_with_chat_completions.supports_response());

        // Both endpoints (has /chat/completions) -> supports_response = false
        assert!(!model_with_both.supports_response());

        // Only /v1/messages endpoint -> supports_response = false (doesn't have /responses)
        assert!(!model_with_messages.supports_response());
    }
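
    #[test]
    fn test_enterprise_configuration_urls() {
        // Illustrative test added in this pass (not part of the original suite):
        // exercises the URL helpers on CopilotChatConfiguration, using a
        // hypothetical enterprise host `ghe.example.com`.
        let config = CopilotChatConfiguration {
            enterprise_uri: Some("https://ghe.example.com/".to_string()),
        };
        assert_eq!(config.oauth_domain(), "ghe.example.com");
        assert_eq!(
            config.token_url(),
            "https://api.ghe.example.com/copilot_internal/v2/token"
        );

        let default_config = CopilotChatConfiguration::default();
        assert_eq!(default_config.oauth_domain(), "github.com");
        assert_eq!(
            default_config.chat_completions_url_from_endpoint("https://api.example.com"),
            "https://api.example.com/chat/completions"
        );
    }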
}