1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::Context as _;
6use anyhow::{Result, anyhow};
7use chrono::DateTime;
8use collections::HashSet;
9use fs::Fs;
10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
11use gpui::WeakEntity;
12use gpui::{App, AsyncApp, Global, prelude::*};
13use http_client::HttpRequestExt;
14use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
15use itertools::Itertools;
16use paths::home_dir;
17use serde::{Deserialize, Serialize};
18
19use crate::copilot_responses as responses;
20use settings::watch_config_dir;
21
/// Name of the environment variable whose value, when set, is used as the
/// initial GitHub OAuth token (read in `CopilotChat::new`).
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
23
/// Connection settings for Copilot Chat.
///
/// With no `enterprise_uri` the public `github.com` endpoints are used;
/// otherwise endpoints are derived from the host part of the enterprise URI.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    /// Base URI of a GitHub Enterprise instance, e.g. `https://github.example.com`.
    /// The scheme (any ASCII case), a path, trailing slashes and surrounding
    /// whitespace are all tolerated; only the host (and port) is used.
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    /// URL used to exchange the OAuth token for a short-lived Copilot API token.
    pub fn token_url(&self) -> String {
        // `oauth_domain` yields `github.com` when no enterprise URI is set, so this
        // covers both the public and the enterprise case.
        format!("https://api.{}/copilot_internal/v2/token", self.oauth_domain())
    }

    /// Domain whose entry is looked up in the Copilot config files to find the
    /// OAuth token: the enterprise host, or `github.com` by default.
    pub fn oauth_domain(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => Self::parse_domain(enterprise_uri),
            None => "github.com".to_string(),
        }
    }

    /// Chat completions URL under the API `endpoint` returned with the API token.
    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        Self::join_endpoint(endpoint, "chat/completions")
    }

    /// Responses API URL under the API `endpoint` returned with the API token.
    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        Self::join_endpoint(endpoint, "responses")
    }

    /// Model listing URL under the API `endpoint` returned with the API token.
    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        Self::join_endpoint(endpoint, "models")
    }

    /// Appends `path` to `endpoint` with exactly one `/` between them, so an
    /// endpoint with a trailing slash does not produce `//` in the URL.
    fn join_endpoint(endpoint: &str, path: &str) -> String {
        format!("{}/{}", endpoint.trim_end_matches('/'), path)
    }

    /// Extracts the host (including any port) from an enterprise URI.
    ///
    /// Surrounding whitespace is ignored, an `http://` or `https://` scheme is
    /// stripped case-insensitively, and everything from the first `/` after the
    /// host onward is dropped.
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim();
        let without_scheme = ["https://", "http://"]
            .into_iter()
            .find_map(|scheme| {
                // `get` returns `None` instead of panicking when the cut would not
                // fall on a char boundary (or the input is shorter than the scheme).
                let prefix = uri.get(..scheme.len())?;
                prefix
                    .eq_ignore_ascii_case(scheme)
                    .then(|| &uri[scheme.len()..])
            })
            .unwrap_or(uri);
        without_scheme
            .split('/')
            .next()
            .unwrap_or(without_scheme)
            .to_string()
    }
}
71
/// Author role of a chat message, serialized in lowercase (`"user"`,
/// `"assistant"`, `"system"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

/// An API endpoint a model advertises in its `supported_endpoints` list,
/// serialized as the endpoint path.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}

/// Top-level shape of the model listing response (see `request_models`).
#[derive(Deserialize)]
struct ModelSchema {
    /// Entries that fail to deserialize as a [`Model`] are logged and dropped
    /// instead of failing the whole listing.
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
93
94fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
95where
96 D: serde::Deserializer<'de>,
97{
98 let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
99 let models = raw_values
100 .into_iter()
101 .filter_map(|value| match serde_json::from_value::<Model>(value) {
102 Ok(model) => Some(model),
103 Err(err) => {
104 log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
105 None
106 }
107 })
108 .collect();
109
110 Ok(models)
111}
112
/// A model entry from the Copilot model listing endpoint.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    /// Human-readable name, exposed via `Model::display_name`.
    name: String,
    /// `get_models` keeps a model when this is absent or its state is `"enabled"`.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    /// `get_models` moves the model with this flag set to the front of the list.
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    /// `get_models` drops models with this flag unset.
    model_picker_enabled: bool,
    /// Endpoints the model can be called through; empty when the field is absent.
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

/// Billing metadata of a model.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // Presumably the premium-request cost multiplier; the API may send it as an
    // integer (e.g. `0` or `1`), which still deserializes into `f64`.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

/// Capability description of a model.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    /// Model family; `get_models` collapses adjacent models sharing a family.
    family: String,
    /// Token limits; all zero when the field is absent.
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    /// Model kind (JSON key `type`); `get_models` keeps only `"chat"` models.
    #[serde(rename = "type")]
    model_type: String,
    /// Tokenizer name such as `cl100k_base` or `o200k_base`, when provided.
    #[serde(default)]
    tokenizer: Option<String>,
}

/// Token limits of a model. Each missing field defaults to `0`.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    /// Returned by `Model::max_token_count`.
    #[serde(default)]
    max_prompt_tokens: u64,
}

/// Access policy attached to a model.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    /// Policy state; `get_models` only treats `"enabled"` as usable.
    state: String,
}

/// Feature flags of a model. Each missing flag defaults to `false`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

/// Vendor of a model, deserialized from the `vendor` string.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

/// One part of a multipart message, tagged by `type` (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Image reference inside an [`ChatMessagePart::Image`] part.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    // NOTE(review): the accepted URL forms (e.g. data URLs) are not constrained
    // here — confirm against the producers of these parts.
    pub url: String,
}
206
207impl Model {
208 pub fn uses_streaming(&self) -> bool {
209 self.capabilities.supports.streaming
210 }
211
212 pub fn id(&self) -> &str {
213 self.id.as_str()
214 }
215
216 pub fn display_name(&self) -> &str {
217 self.name.as_str()
218 }
219
220 pub fn max_token_count(&self) -> u64 {
221 self.capabilities.limits.max_prompt_tokens
222 }
223
224 pub fn supports_tools(&self) -> bool {
225 self.capabilities.supports.tool_calls
226 }
227
228 pub fn vendor(&self) -> ModelVendor {
229 self.vendor
230 }
231
232 pub fn supports_vision(&self) -> bool {
233 self.capabilities.supports.vision
234 }
235
236 pub fn supports_parallel_tool_calls(&self) -> bool {
237 self.capabilities.supports.parallel_tool_calls
238 }
239
240 pub fn tokenizer(&self) -> Option<&str> {
241 self.capabilities.tokenizer.as_deref()
242 }
243
244 pub fn supports_response(&self) -> bool {
245 self.supported_endpoints.len() > 0
246 && !self
247 .supported_endpoints
248 .contains(&ModelSupportedEndpoint::ChatCompletions)
249 && self
250 .supported_endpoints
251 .contains(&ModelSupportedEndpoint::Responses)
252 }
253}
254
/// Body of a `/chat/completions` request.
#[derive(Serialize, Deserialize)]
pub struct Request {
    // NOTE(review): Copilot-specific flag; its server-side meaning is not visible here.
    pub intent: bool,
    // Presumably the OpenAI-style number of choices to generate — confirm.
    pub n: usize,
    /// When true, `stream_completion` reads the reply as server-sent events;
    /// otherwise it parses a single JSON response.
    pub stream: bool,
    pub temperature: f32,
    /// Identifier of the model to use.
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// Omitted from the serialized body when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Omitted from the serialized body when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// A function the model may call.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // Presumably a JSON Schema describing the arguments — confirm with callers.
    pub parameters: serde_json::Value,
}

/// A tool offered to the model, serialized as `{"type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// Tool selection mode, serialized in lowercase (`"auto"`, `"any"`, `"none"`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

/// A chat message, internally tagged by its lowercase `role` field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        /// Omitted from the serialized message when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: ChatMessageContent,
    },
    /// System messages only carry plain text.
    System {
        content: String,
    },
    /// Result of a tool call, linked to it by `tool_call_id`.
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}
309
/// Message content: either a bare string or a list of typed parts.
///
/// `untagged`, so it serializes as a JSON string or a JSON array respectively.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}
316
317impl ChatMessageContent {
318 pub fn empty() -> Self {
319 ChatMessageContent::Multipart(vec![])
320 }
321}
322
323impl From<Vec<ChatMessagePart>> for ChatMessageContent {
324 fn from(mut parts: Vec<ChatMessagePart>) -> Self {
325 if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
326 ChatMessageContent::Plain(std::mem::take(text))
327 } else {
328 ChatMessageContent::Multipart(parts)
329 }
330 }
331}
332
333impl From<String> for ChatMessageContent {
334 fn from(text: String) -> Self {
335 ChatMessageContent::Plain(text)
336 }
337}
338
/// A tool call made by the assistant.
///
/// `content` is flattened, so the JSON shape is
/// `{"id": ..., "type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Kind-specific payload of a tool call, tagged by lowercase `type`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name and arguments of a function tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Presumably a JSON-encoded argument object carried as a string — confirm.
    pub arguments: String,
}
357
/// One chat completion event: a streamed chunk, or the whole reply when the
/// request was not streamed (see `stream_completion`).
// NOTE(review): `tag`/`rename_all` on a struct that only derives `Deserialize`
// look like no-ops (the fields are already snake_case, and per the serde docs
// `tag` on a struct adds a field when *serializing*) — confirm and consider removing.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}

/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a [`ResponseEvent`].
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    // Presumably `delta` is set on streamed chunks and `message` on
    // non-streamed replies — TODO confirm.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Content carried by a choice (an incremental delta or a full message).
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    /// Empty when the field is absent.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}
/// A possibly partial tool call.
// Presumably fragments sharing an `index` are concatenated by the consumer.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A possibly partial function name/arguments fragment.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
400
/// Raw response of the token exchange endpoint (`token_url`).
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    /// Expiry as a Unix timestamp in seconds (converted in `ApiToken::try_from`).
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

/// Endpoint section of the token exchange response.
#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    /// Base URL for chat, responses and model-listing requests.
    api: String,
}

/// A validated, short-lived Copilot API token.
#[derive(Clone)]
struct ApiToken {
    /// Sent as the `Bearer` credential on API requests.
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    /// Base URL that request paths are appended to.
    api_endpoint: String,
}
419
420impl ApiToken {
421 pub fn remaining_seconds(&self) -> i64 {
422 self.expires_at
423 .timestamp()
424 .saturating_sub(chrono::Utc::now().timestamp())
425 }
426}
427
428impl TryFrom<ApiTokenResponse> for ApiToken {
429 type Error = anyhow::Error;
430
431 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
432 let expires_at =
433 DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
434
435 Ok(Self {
436 api_key: response.token,
437 expires_at,
438 api_endpoint: response.endpoints.api,
439 })
440 }
441}
442
/// gpui global holding the shared [`CopilotChat`] entity; installed by [`init`].
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

/// Shared Copilot Chat state: credentials, configuration and cached models.
pub struct CopilotChat {
    /// GitHub OAuth token taken from `GH_COPILOT_TOKEN` or the Copilot config
    /// files; `None` means not authenticated.
    oauth_token: Option<String>,
    /// Cached short-lived API token; refreshed on demand by `get_auth_details`.
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    /// Filtered chat models; `None` until the first successful fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
454
455pub fn init(
456 fs: Arc<dyn Fs>,
457 client: Arc<dyn HttpClient>,
458 configuration: CopilotChatConfiguration,
459 cx: &mut App,
460) {
461 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
462 cx.set_global(GlobalCopilotChat(copilot_chat));
463}
464
465pub fn copilot_chat_config_dir() -> &'static PathBuf {
466 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
467
468 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
469 let config_dir = if cfg!(target_os = "windows") {
470 dirs::data_local_dir().expect("failed to determine LocalAppData directory")
471 } else {
472 std::env::var("XDG_CONFIG_HOME")
473 .map(PathBuf::from)
474 .unwrap_or_else(|_| home_dir().join(".config"))
475 };
476
477 config_dir.join("github-copilot")
478 })
479}
480
481fn copilot_chat_config_paths() -> [PathBuf; 2] {
482 let base_dir = copilot_chat_config_dir();
483 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
484}
485
486impl CopilotChat {
487 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
488 cx.try_global::<GlobalCopilotChat>()
489 .map(|model| model.0.clone())
490 }
491
492 fn new(
493 fs: Arc<dyn Fs>,
494 client: Arc<dyn HttpClient>,
495 configuration: CopilotChatConfiguration,
496 cx: &mut Context<Self>,
497 ) -> Self {
498 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
499 let dir_path = copilot_chat_config_dir();
500
501 cx.spawn(async move |this, cx| {
502 let mut parent_watch_rx = watch_config_dir(
503 cx.background_executor(),
504 fs.clone(),
505 dir_path.clone(),
506 config_paths,
507 );
508 while let Some(contents) = parent_watch_rx.next().await {
509 let oauth_domain =
510 this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
511 let oauth_token = extract_oauth_token(contents, &oauth_domain);
512
513 this.update(cx, |this, cx| {
514 this.oauth_token = oauth_token.clone();
515 cx.notify();
516 })?;
517
518 if oauth_token.is_some() {
519 Self::update_models(&this, cx).await?;
520 }
521 }
522 anyhow::Ok(())
523 })
524 .detach_and_log_err(cx);
525
526 let this = Self {
527 oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
528 api_token: None,
529 models: None,
530 configuration,
531 client,
532 };
533
534 if this.oauth_token.is_some() {
535 cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
536 .detach_and_log_err(cx);
537 }
538
539 this
540 }
541
542 async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
543 let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
544 (
545 this.oauth_token.clone(),
546 this.client.clone(),
547 this.configuration.clone(),
548 )
549 })?;
550
551 let oauth_token = oauth_token
552 .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
553
554 let token_url = configuration.token_url();
555 let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
556
557 let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
558 let models =
559 get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
560
561 this.update(cx, |this, cx| {
562 this.api_token = Some(api_token);
563 this.models = Some(models);
564 cx.notify();
565 })?;
566 anyhow::Ok(())
567 }
568
569 pub fn is_authenticated(&self) -> bool {
570 self.oauth_token.is_some()
571 }
572
573 pub fn models(&self) -> Option<&[Model]> {
574 self.models.as_deref()
575 }
576
577 pub async fn stream_completion(
578 request: Request,
579 is_user_initiated: bool,
580 mut cx: AsyncApp,
581 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
582 let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
583
584 let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
585 stream_completion(
586 client.clone(),
587 token.api_key,
588 api_url.into(),
589 request,
590 is_user_initiated,
591 )
592 .await
593 }
594
595 pub async fn stream_response(
596 request: responses::Request,
597 is_user_initiated: bool,
598 mut cx: AsyncApp,
599 ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
600 let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
601
602 let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
603 responses::stream_response(
604 client.clone(),
605 token.api_key,
606 api_url,
607 request,
608 is_user_initiated,
609 )
610 .await
611 }
612
613 async fn get_auth_details(
614 cx: &mut AsyncApp,
615 ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
616 let this = cx
617 .update(|cx| Self::global(cx))
618 .ok()
619 .flatten()
620 .context("Copilot chat is not enabled")?;
621
622 let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
623 (
624 this.oauth_token.clone(),
625 this.api_token.clone(),
626 this.client.clone(),
627 this.configuration.clone(),
628 )
629 })?;
630
631 let oauth_token = oauth_token.context("No OAuth token available")?;
632
633 let token = match api_token {
634 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
635 _ => {
636 let token_url = configuration.token_url();
637 let token =
638 request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
639 this.update(cx, |this, cx| {
640 this.api_token = Some(token.clone());
641 cx.notify();
642 })?;
643 token
644 }
645 };
646
647 Ok((client, token, configuration))
648 }
649
650 pub fn set_configuration(
651 &mut self,
652 configuration: CopilotChatConfiguration,
653 cx: &mut Context<Self>,
654 ) {
655 let same_configuration = self.configuration == configuration;
656 self.configuration = configuration;
657 if !same_configuration {
658 self.api_token = None;
659 cx.spawn(async move |this, cx| {
660 Self::update_models(&this, cx).await?;
661 Ok::<_, anyhow::Error>(())
662 })
663 .detach();
664 }
665 }
666}
667
668async fn get_models(
669 models_url: Arc<str>,
670 api_token: String,
671 client: Arc<dyn HttpClient>,
672) -> Result<Vec<Model>> {
673 let all_models = request_models(models_url, api_token, client).await?;
674
675 let mut models: Vec<Model> = all_models
676 .into_iter()
677 .filter(|model| {
678 model.model_picker_enabled
679 && model.capabilities.model_type.as_str() == "chat"
680 && model
681 .policy
682 .as_ref()
683 .is_none_or(|policy| policy.state == "enabled")
684 })
685 .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
686 .collect();
687
688 if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
689 let default_model = models.remove(default_model_position);
690 models.insert(0, default_model);
691 }
692
693 Ok(models)
694}
695
696async fn request_models(
697 models_url: Arc<str>,
698 api_token: String,
699 client: Arc<dyn HttpClient>,
700) -> Result<Vec<Model>> {
701 let request_builder = HttpRequest::builder()
702 .method(Method::GET)
703 .uri(models_url.as_ref())
704 .header("Authorization", format!("Bearer {}", api_token))
705 .header("Content-Type", "application/json")
706 .header("Copilot-Integration-Id", "vscode-chat")
707 .header("Editor-Version", "vscode/1.103.2")
708 .header("x-github-api-version", "2025-05-01");
709
710 let request = request_builder.body(AsyncBody::empty())?;
711
712 let mut response = client.send(request).await?;
713
714 anyhow::ensure!(
715 response.status().is_success(),
716 "Failed to request models: {}",
717 response.status()
718 );
719 let mut body = Vec::new();
720 response.body_mut().read_to_end(&mut body).await?;
721
722 let body_str = std::str::from_utf8(&body)?;
723
724 let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
725
726 Ok(models)
727}
728
729async fn request_api_token(
730 oauth_token: &str,
731 auth_url: Arc<str>,
732 client: Arc<dyn HttpClient>,
733) -> Result<ApiToken> {
734 let request_builder = HttpRequest::builder()
735 .method(Method::GET)
736 .uri(auth_url.as_ref())
737 .header("Authorization", format!("token {}", oauth_token))
738 .header("Accept", "application/json");
739
740 let request = request_builder.body(AsyncBody::empty())?;
741
742 let mut response = client.send(request).await?;
743
744 if response.status().is_success() {
745 let mut body = Vec::new();
746 response.body_mut().read_to_end(&mut body).await?;
747
748 let body_str = std::str::from_utf8(&body)?;
749
750 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
751 ApiToken::try_from(parsed)
752 } else {
753 let mut body = Vec::new();
754 response.body_mut().read_to_end(&mut body).await?;
755
756 let body_str = std::str::from_utf8(&body)?;
757 anyhow::bail!("Failed to request API token: {body_str}");
758 }
759}
760
761fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
762 serde_json::from_str::<serde_json::Value>(&contents)
763 .map(|v| {
764 v.as_object().and_then(|obj| {
765 obj.iter().find_map(|(key, value)| {
766 if key.starts_with(domain) {
767 value["oauth_token"].as_str().map(|v| v.to_string())
768 } else {
769 None
770 }
771 })
772 })
773 })
774 .ok()
775 .flatten()
776}
777
778async fn stream_completion(
779 client: Arc<dyn HttpClient>,
780 api_key: String,
781 completion_url: Arc<str>,
782 request: Request,
783 is_user_initiated: bool,
784) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
785 let is_vision_request = request.messages.iter().any(|message| match message {
786 ChatMessage::User { content }
787 | ChatMessage::Assistant { content, .. }
788 | ChatMessage::Tool { content, .. } => {
789 matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
790 }
791 _ => false,
792 });
793
794 let request_initiator = if is_user_initiated { "user" } else { "agent" };
795
796 let request_builder = HttpRequest::builder()
797 .method(Method::POST)
798 .uri(completion_url.as_ref())
799 .header(
800 "Editor-Version",
801 format!(
802 "Zed/{}",
803 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
804 ),
805 )
806 .header("Authorization", format!("Bearer {}", api_key))
807 .header("Content-Type", "application/json")
808 .header("Copilot-Integration-Id", "vscode-chat")
809 .header("X-Initiator", request_initiator)
810 .when(is_vision_request, |builder| {
811 builder.header("Copilot-Vision-Request", is_vision_request.to_string())
812 });
813
814 let is_streaming = request.stream;
815
816 let json = serde_json::to_string(&request)?;
817 let request = request_builder.body(AsyncBody::from(json))?;
818 let mut response = client.send(request).await?;
819
820 if !response.status().is_success() {
821 let mut body = Vec::new();
822 response.body_mut().read_to_end(&mut body).await?;
823 let body_str = std::str::from_utf8(&body)?;
824 anyhow::bail!(
825 "Failed to connect to API: {} {}",
826 response.status(),
827 body_str
828 );
829 }
830
831 if is_streaming {
832 let reader = BufReader::new(response.into_body());
833 Ok(reader
834 .lines()
835 .filter_map(|line| async move {
836 match line {
837 Ok(line) => {
838 let line = line.strip_prefix("data: ")?;
839 if line.starts_with("[DONE]") {
840 return None;
841 }
842
843 match serde_json::from_str::<ResponseEvent>(line) {
844 Ok(response) => {
845 if response.choices.is_empty() {
846 None
847 } else {
848 Some(Ok(response))
849 }
850 }
851 Err(error) => Some(Err(anyhow!(error))),
852 }
853 }
854 Err(error) => Some(Err(anyhow!(error))),
855 }
856 })
857 .boxed())
858 } else {
859 let mut body = Vec::new();
860 response.body_mut().read_to_end(&mut body).await?;
861 let body_str = std::str::from_utf8(&body)?;
862 let response: ResponseEvent = serde_json::from_str(body_str)?;
863
864 Ok(futures::stream::once(async move { Ok(response) }).boxed())
865 }
866}
867
#[cfg(test)]
mod tests {
    use super::*;

    /// A malformed entry in `data` must be skipped without failing the whole
    /// listing. The middle entry has none of the required fields, so exactly
    /// the two well-formed models survive, in their original order. The first
    /// model also exercises the `"Azure OpenAI"` vendor alias and an integer
    /// `multiplier`.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
              "data": [
                {
                  "billing": {
                    "is_premium": false,
                    "multiplier": 0
                  },
                  "capabilities": {
                    "family": "gpt-4",
                    "limits": {
                      "max_context_window_tokens": 32768,
                      "max_output_tokens": 4096,
                      "max_prompt_tokens": 32768
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "tokenizer": "cl100k_base",
                    "type": "chat"
                  },
                  "id": "gpt-4",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": false,
                  "name": "GPT 4",
                  "object": "model",
                  "preview": false,
                  "vendor": "Azure OpenAI",
                  "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                  "billing": {
                    "is_premium": true,
                    "multiplier": 1,
                    "restricted_to": [
                      "pro",
                      "pro_plus",
                      "business",
                      "enterprise"
                    ]
                  },
                  "capabilities": {
                    "family": "claude-3.7-sonnet",
                    "limits": {
                      "max_context_window_tokens": 200000,
                      "max_output_tokens": 16384,
                      "max_prompt_tokens": 90000,
                      "vision": {
                        "max_prompt_image_size": 3145728,
                        "max_prompt_images": 1,
                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                      }
                    },
                    "object": "model_capabilities",
                    "supports": {
                      "parallel_tool_calls": true,
                      "streaming": true,
                      "tool_calls": true,
                      "vision": true
                    },
                    "tokenizer": "o200k_base",
                    "type": "chat"
                  },
                  "id": "claude-3.7-sonnet",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": true,
                  "name": "Claude 3.7 Sonnet",
                  "object": "model",
                  "policy": {
                    "state": "enabled",
                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                  },
                  "preview": false,
                  "vendor": "Anthropic",
                  "version": "claude-3.7-sonnet"
                }
              ],
              "object": "list"
            }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // Only the malformed middle entry is dropped.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    /// A vendor string not covered by `ModelVendor` must map to
    /// `ModelVendor::Unknown` (via `#[serde(other)]`) instead of dropping the model.
    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
              "data": [
                {
                  "billing": {
                    "is_premium": false,
                    "multiplier": 1
                  },
                  "capabilities": {
                    "family": "future-model",
                    "limits": {
                      "max_context_window_tokens": 128000,
                      "max_output_tokens": 8192,
                      "max_prompt_tokens": 120000
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "type": "chat"
                  },
                  "id": "future-model-v1",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": true,
                  "name": "Future Model v1",
                  "object": "model",
                  "preview": false,
                  "vendor": "SomeNewVendor",
                  "version": "v1.0"
                }
              ],
              "object": "list"
            }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
}