use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;

use anyhow::Context as _;
use anyhow::{Result, anyhow};
use chrono::DateTime;
use collections::HashSet;
use fs::Fs;
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use gpui::WeakEntity;
use gpui::{App, AsyncApp, Global, prelude::*};
use http_client::HttpRequestExt;
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use itertools::Itertools;
use paths::home_dir;
use serde::{Deserialize, Serialize};

use crate::copilot_responses as responses;
use settings::watch_config_dir;

pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";

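/// Configuration for GitHub Copilot Chat. When `enterprise_uri` is set, token
/// requests and OAuth lookups target that GitHub Enterprise host instead of
/// github.com.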
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    pub fn token_url(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            let domain = Self::parse_domain(enterprise_uri);
            format!("https://api.{}/copilot_internal/v2/token", domain)
        } else {
            "https://api.github.com/copilot_internal/v2/token".to_string()
        }
    }

    pub fn oauth_domain(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            Self::parse_domain(enterprise_uri)
        } else {
            "github.com".to_string()
        }
    }

    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/chat/completions", endpoint)
    }

    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/responses", endpoint)
    }

    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/models", endpoint)
    }

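    /// Extracts the host from an enterprise URI, tolerating an optional
    /// `https://`/`http://` scheme, trailing slashes, and any path segments.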
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');

        if let Some(domain) = uri.strip_prefix("https://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else if let Some(domain) = uri.strip_prefix("http://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else {
            uri.split('/').next().unwrap_or(uri).to_string()
        }
    }
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}

#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}

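/// Deserializes the `data` array while skipping (and logging) entries that fail
/// to parse, so a single malformed model does not invalidate the whole list.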
fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
    let models = raw_values
        .into_iter()
        .filter_map(|value| match serde_json::from_value::<Model>(value) {
            Ok(model) => Some(model),
            Err(err) => {
                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
                None
            }
        })
        .collect();

    Ok(models)
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value set to true is selected by VS Code Copilot when the premium
    // request limit is reached. Zed does not currently implement this behaviour.
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}

impl Model {
    pub fn uses_streaming(&self) -> bool {
        self.capabilities.supports.streaming
    }

    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    pub fn display_name(&self) -> &str {
        self.name.as_str()
    }

    pub fn max_token_count(&self) -> u64 {
        self.capabilities.limits.max_prompt_tokens
    }

    pub fn supports_tools(&self) -> bool {
        self.capabilities.supports.tool_calls
    }

    pub fn vendor(&self) -> ModelVendor {
        self.vendor
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.supports.vision
    }

    pub fn supports_parallel_tool_calls(&self) -> bool {
        self.capabilities.supports.parallel_tool_calls
    }

    pub fn tokenizer(&self) -> Option<&str> {
        self.capabilities.tokenizer.as_deref()
    }

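    /// Whether this model should be called via the `/responses` endpoint: the
    /// model must advertise `/responses` and must not advertise
    /// `/chat/completions`.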
    pub fn supports_response(&self) -> bool {
        !self.supported_endpoints.is_empty()
            && !self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::ChatCompletions)
            && self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::Responses)
    }
}

#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}

impl From<Vec<ChatMessagePart>> for ChatMessageContent {
    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
            ChatMessageContent::Plain(std::mem::take(text))
        } else {
            ChatMessageContent::Multipart(parts)
        }
    }
}

impl From<String> for ChatMessageContent {
    fn from(text: String) -> Self {
        ChatMessageContent::Plain(text)
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Deserialize, Debug)]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}

#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}

#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}

impl ApiToken {
    pub fn remaining_seconds(&self) -> i64 {
        self.expires_at
            .timestamp()
            .saturating_sub(chrono::Utc::now().timestamp())
    }
}

impl TryFrom<ApiTokenResponse> for ApiToken {
    type Error = anyhow::Error;

    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
        let expires_at =
            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;

        Ok(Self {
            api_key: response.token,
            expires_at,
            api_endpoint: response.endpoints.api,
        })
    }
}

struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

pub struct CopilotChat {
    oauth_token: Option<String>,
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}

pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}

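/// Directory holding the GitHub Copilot configuration files:
/// `LocalAppData\github-copilot` on Windows, otherwise
/// `$XDG_CONFIG_HOME/github-copilot` (falling back to `~/.config/github-copilot`).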
pub fn copilot_chat_config_dir() -> &'static PathBuf {
    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();

    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
        let config_dir = if cfg!(target_os = "windows") {
            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
        } else {
            std::env::var("XDG_CONFIG_HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| home_dir().join(".config"))
        };

        config_dir.join("github-copilot")
    })
}

fn copilot_chat_config_paths() -> [PathBuf; 2] {
    let base_dir = copilot_chat_config_dir();
    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
}

impl CopilotChat {
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_token: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

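    /// Exchanges the stored OAuth token for a fresh API token, fetches the model
    /// list from the returned API endpoint, and stores both on the entity.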
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let token_url = configuration.token_url();
        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;

        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
        let models =
            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;

        this.update(cx, |this, cx| {
            this.api_token = Some(api_token);
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
        stream_completion(
            client.clone(),
            token.api_key,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
        responses::stream_response(
            client.clone(),
            token.api_key,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

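    /// Returns the HTTP client, a valid API token, and the current configuration,
    /// refreshing the API token when it has less than five minutes of validity left.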
    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
        let this = cx
            .update(|cx| Self::global(cx))
            .ok()
            .flatten()
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
            _ => {
                let token_url = configuration.token_url();
                let token =
                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
                this.update(cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        Ok((client, token, configuration))
    }

    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_token = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}

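/// Fetches the model list and keeps only chat models that are enabled in the
/// model picker and permitted by policy, collapsing consecutive entries that
/// share a model family and moving the default chat model to the front.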
async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}

async fn request_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(models_url.as_ref())
        .header("Authorization", format!("Bearer {}", api_token))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("Editor-Version", "vscode/1.103.2")
        .header("x-github-api-version", "2025-05-01");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to request models: {}",
        response.status()
    );
    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    let body_str = std::str::from_utf8(&body)?;

    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;

    Ok(models)
}

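/// Exchanges a GitHub OAuth token for a short-lived Copilot API token by
/// calling the configured token endpoint with a `token` Authorization header.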
async fn request_api_token(
    oauth_token: &str,
    auth_url: Arc<str>,
    client: Arc<dyn HttpClient>,
) -> Result<ApiToken> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(auth_url.as_ref())
        .header("Authorization", format!("token {}", oauth_token))
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
        ApiToken::try_from(parsed)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!("Failed to request API token: {body_str}");
    }
}

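/// Looks up the `oauth_token` value in the Copilot `hosts.json`/`apps.json`
/// contents, using the first top-level key that starts with the given domain.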
fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
    serde_json::from_str::<serde_json::Value>(&contents)
        .map(|v| {
            v.as_object().and_then(|obj| {
                obj.iter().find_map(|(key, value)| {
                    if key.starts_with(domain) {
                        value["oauth_token"].as_str().map(|v| v.to_string())
                    } else {
                        None
                    }
                })
            })
        })
        .ok()
        .flatten()
}

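/// Sends a chat completion request. When `request.stream` is true the response
/// is consumed as server-sent events: `data: ` lines are parsed as
/// `ResponseEvent`s and the `[DONE]` sentinel is skipped; otherwise the whole
/// body is parsed as a single `ResponseEvent`.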
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator)
        .when(is_vision_request, |builder| {
            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
        });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 0
                    },
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "billing": {
                        "is_premium": true,
                        "multiplier": 1,
                        "restricted_to": [
                            "pro",
                            "pro_plus",
                            "business",
                            "enterprise"
                        ]
                    },
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 1
                    },
                    "capabilities": {
                        "family": "future-model",
                        "limits": {
                            "max_context_window_tokens": 128000,
                            "max_output_tokens": 8192,
                            "max_prompt_tokens": 120000
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "future-model-v1",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Future Model v1",
                    "object": "model",
                    "preview": false,
                    "vendor": "SomeNewVendor",
                    "version": "v1.0"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
}
1007}