use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;

use anyhow::Context as _;
use anyhow::{Result, anyhow};
use chrono::DateTime;
use collections::HashSet;
use fs::Fs;
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use gpui::WeakEntity;
use gpui::{App, AsyncApp, Global, prelude::*};
use http_client::HttpRequestExt;
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use itertools::Itertools;
use paths::home_dir;
use serde::{Deserialize, Serialize};

use crate::copilot_responses as responses;
use settings::watch_config_dir;

pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";

#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    pub fn token_url(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            let domain = Self::parse_domain(enterprise_uri);
            format!("https://api.{}/copilot_internal/v2/token", domain)
        } else {
            "https://api.github.com/copilot_internal/v2/token".to_string()
        }
    }

    pub fn oauth_domain(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            Self::parse_domain(enterprise_uri)
        } else {
            "github.com".to_string()
        }
    }

    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/chat/completions", endpoint)
    }

    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/responses", endpoint)
    }

    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/models", endpoint)
    }

    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');

        if let Some(domain) = uri.strip_prefix("https://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else if let Some(domain) = uri.strip_prefix("http://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else {
            uri.split('/').next().unwrap_or(uri).to_string()
        }
    }
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}

#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}

fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
    let models = raw_values
        .into_iter()
        .filter_map(|value| match serde_json::from_value::<Model>(value) {
            Ok(model) => Some(model),
            Err(err) => {
                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
                None
            }
        })
        .collect();

    Ok(models)
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}

impl Model {
    pub fn uses_streaming(&self) -> bool {
        self.capabilities.supports.streaming
    }

    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    pub fn display_name(&self) -> &str {
        self.name.as_str()
    }

    pub fn max_token_count(&self) -> u64 {
        self.capabilities.limits.max_prompt_tokens
    }

    pub fn supports_tools(&self) -> bool {
        self.capabilities.supports.tool_calls
    }

    pub fn vendor(&self) -> ModelVendor {
        self.vendor
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.supports.vision
    }

    pub fn supports_parallel_tool_calls(&self) -> bool {
        self.capabilities.supports.parallel_tool_calls
    }

    pub fn tokenizer(&self) -> Option<&str> {
        self.capabilities.tokenizer.as_deref()
    }

    pub fn supports_response(&self) -> bool {
        !self.supported_endpoints.is_empty()
            && !self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::ChatCompletions)
            && self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::Responses)
    }
}

#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}

impl From<Vec<ChatMessagePart>> for ChatMessageContent {
    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
            ChatMessageContent::Plain(std::mem::take(text))
        } else {
            ChatMessageContent::Multipart(parts)
        }
    }
}

impl From<String> for ChatMessageContent {
    fn from(text: String) -> Self {
        ChatMessageContent::Plain(text)
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}

#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}

#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}

impl ApiToken {
    pub fn remaining_seconds(&self) -> i64 {
        self.expires_at
            .timestamp()
            .saturating_sub(chrono::Utc::now().timestamp())
    }
}

impl TryFrom<ApiTokenResponse> for ApiToken {
    type Error = anyhow::Error;

    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
        let expires_at =
            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;

        Ok(Self {
            api_key: response.token,
            expires_at,
            api_endpoint: response.endpoints.api,
        })
    }
}

struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

pub struct CopilotChat {
    oauth_token: Option<String>,
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}

pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}

pub fn copilot_chat_config_dir() -> &'static PathBuf {
    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();

    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
        let config_dir = if cfg!(target_os = "windows") {
            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
        } else {
            std::env::var("XDG_CONFIG_HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| home_dir().join(".config"))
        };

        config_dir.join("github-copilot")
    })
}

fn copilot_chat_config_paths() -> [PathBuf; 2] {
    let base_dir = copilot_chat_config_dir();
    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
}

impl CopilotChat {
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_token: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let token_url = configuration.token_url();
        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;

        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
        let models =
            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;

        this.update(cx, |this, cx| {
            this.api_token = Some(api_token);
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
        stream_completion(
            client.clone(),
            token.api_key,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
        responses::stream_response(
            client.clone(),
            token.api_key,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
        let this = cx
            .update(|cx| Self::global(cx))
            .ok()
            .flatten()
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
            _ => {
                let token_url = configuration.token_url();
                let token =
                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
                this.update(cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        Ok((client, token, configuration))
    }

    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_token = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}

async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}

async fn request_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(models_url.as_ref())
        .header("Authorization", format!("Bearer {}", api_token))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("Editor-Version", "vscode/1.103.2")
        .header("x-github-api-version", "2025-05-01");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to request models: {}",
        response.status()
    );
    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    let body_str = std::str::from_utf8(&body)?;

    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;

    Ok(models)
}

async fn request_api_token(
    oauth_token: &str,
    auth_url: Arc<str>,
    client: Arc<dyn HttpClient>,
) -> Result<ApiToken> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(auth_url.as_ref())
        .header("Authorization", format!("token {}", oauth_token))
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
        ApiToken::try_from(parsed)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!("Failed to request API token: {body_str}");
    }
}

fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
    serde_json::from_str::<serde_json::Value>(&contents)
        .map(|v| {
            v.as_object().and_then(|obj| {
                obj.iter().find_map(|(key, value)| {
                    if key.starts_with(domain) {
                        value["oauth_token"].as_str().map(|v| v.to_string())
                    } else {
                        None
                    }
                })
            })
        })
        .ok()
        .flatten()
}

async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator)
        .when(is_vision_request, |builder| {
            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
        });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
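
    // A minimal sketch exercising the URL construction in
    // `CopilotChatConfiguration`; the enterprise host and API endpoint below
    // are made up purely for illustration.
    #[test]
    fn test_configuration_urls() {
        let enterprise = CopilotChatConfiguration {
            enterprise_uri: Some("https://example.ghe.com/".to_string()),
        };
        assert_eq!(enterprise.oauth_domain(), "example.ghe.com");
        assert_eq!(
            enterprise.token_url(),
            "https://api.example.ghe.com/copilot_internal/v2/token"
        );

        let default_config = CopilotChatConfiguration::default();
        assert_eq!(default_config.oauth_domain(), "github.com");
        assert_eq!(
            default_config.token_url(),
            "https://api.github.com/copilot_internal/v2/token"
        );
        assert_eq!(
            default_config.models_url_from_endpoint("https://api.example-endpoint.com"),
            "https://api.example-endpoint.com/models"
        );
    }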

    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 0
                    },
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "billing": {
                        "is_premium": true,
                        "multiplier": 1,
                        "restricted_to": [
                            "pro",
                            "pro_plus",
                            "business",
                            "enterprise"
                        ]
                    },
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
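
    // A small sketch of the `From<Vec<ChatMessagePart>>` conversion: a single
    // text part collapses to `Plain`, anything else stays `Multipart`. The
    // strings and URL are placeholder values.
    #[test]
    fn test_chat_message_content_from_parts() {
        let single = ChatMessageContent::from(vec![ChatMessagePart::Text {
            text: "hello".to_string(),
        }]);
        assert!(matches!(single, ChatMessageContent::Plain(text) if text == "hello"));

        let multipart = ChatMessageContent::from(vec![
            ChatMessagePart::Text {
                text: "describe this image".to_string(),
            },
            ChatMessagePart::Image {
                image_url: ImageUrl {
                    url: "https://example.com/image.png".to_string(),
                },
            },
        ]);
        assert!(matches!(multipart, ChatMessageContent::Multipart(parts) if parts.len() == 2));
    }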

    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 1
                    },
                    "capabilities": {
                        "family": "future-model",
                        "limits": {
                            "max_context_window_tokens": 128000,
                            "max_output_tokens": 8192,
                            "max_prompt_tokens": 120000
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "future-model-v1",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Future Model v1",
                    "object": "model",
                    "preview": false,
                    "vendor": "SomeNewVendor",
                    "version": "v1.0"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
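
    // A sketch of the config-file lookup in `extract_oauth_token`: any key
    // that starts with the OAuth domain yields its `oauth_token`. The keys and
    // token values below are made up for illustration and only approximate
    // the real hosts.json/apps.json shape.
    #[test]
    fn test_extract_oauth_token_matches_domain_prefix() {
        let contents = r#"{
            "github.com:Iv1.example-app": { "oauth_token": "gho_example" },
            "other.example.com": { "oauth_token": "gho_other" }
        }"#
        .to_string();

        assert_eq!(
            extract_oauth_token(contents.clone(), "github.com"),
            Some("gho_example".to_string())
        );
        assert_eq!(extract_oauth_token(contents, "missing.example.com"), None);
    }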
}