1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::Context as _;
6use anyhow::{Result, anyhow};
7use chrono::DateTime;
8use collections::HashSet;
9use fs::Fs;
10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
11use gpui::WeakEntity;
12use gpui::{App, AsyncApp, Global, prelude::*};
13use http_client::HttpRequestExt;
14use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
15use itertools::Itertools;
16use paths::home_dir;
17use serde::{Deserialize, Serialize};
18
19use crate::copilot_responses as responses;
20use settings::watch_config_dir;
21
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";

/// Configuration for connecting to GitHub Copilot Chat, optionally through a
/// GitHub Enterprise deployment.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    /// URI of a GitHub Enterprise instance (e.g. `https://example.ghe.com`);
    /// `None` means public github.com.
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    /// URL used to exchange the long-lived OAuth token for a short-lived API token.
    pub fn token_url(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => format!(
                "https://api.{}/copilot_internal/v2/token",
                Self::parse_domain(enterprise_uri)
            ),
            None => "https://api.github.com/copilot_internal/v2/token".to_string(),
        }
    }

    /// Domain whose credentials should be looked up in the Copilot OAuth config files.
    pub fn oauth_domain(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => Self::parse_domain(enterprise_uri),
            None => "github.com".to_string(),
        }
    }

    /// Chat-completions URL relative to the API endpoint returned by the token exchange.
    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/chat/completions", endpoint)
    }

    /// Responses-API URL relative to the API endpoint returned by the token exchange.
    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/responses", endpoint)
    }

    /// Model-listing URL relative to the API endpoint returned by the token exchange.
    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/models", endpoint)
    }

    /// Extracts the bare domain from an enterprise URI, tolerating an optional
    /// `http(s)://` scheme, trailing path components, and a trailing slash.
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');
        // Drop an optional scheme, then keep everything up to the first `/`.
        let without_scheme = uri
            .strip_prefix("https://")
            .or_else(|| uri.strip_prefix("http://"))
            .unwrap_or(uri);
        without_scheme
            .split('/')
            .next()
            .unwrap_or(without_scheme)
            .to_string()
    }
}
71
/// Author role of a chat message; serialized in lowercase for the Copilot API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

/// API endpoint a model can be served through, as advertised by the model list.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}
87
88#[derive(Deserialize)]
89struct ModelSchema {
90 #[serde(deserialize_with = "deserialize_models_skip_errors")]
91 data: Vec<Model>,
92}
93
94fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
95where
96 D: serde::Deserializer<'de>,
97{
98 let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
99 let models = raw_values
100 .into_iter()
101 .filter_map(|value| match serde_json::from_value::<Model>(value) {
102 Ok(model) => Some(model),
103 Err(err) => {
104 log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
105 None
106 }
107 })
108 .collect();
109
110 Ok(models)
111}
112
/// A model entry from the Copilot `/models` listing.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    // Stable identifier sent in API requests.
    id: String,
    // Human-readable display name.
    name: String,
    // Per-user policy gate; `None` means no policy applies.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    // Whether the model should be offered in the model picker UI.
    model_picker_enabled: bool,
    // Endpoints this model can be served through; empty when not advertised.
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

/// Billing metadata for a model.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // Premium-request multiplier applied to usage of this model.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

/// Capability metadata for a model (family, limits, feature support).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // e.g. "chat"; non-chat models are filtered out in `get_models`.
    #[serde(rename = "type")]
    model_type: String,
    // Tokenizer name for local token counting, when advertised.
    #[serde(default)]
    tokenizer: Option<String>,
}

/// Token limits for a model; fields default to 0 when absent.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

/// Policy gate for a model; only models with state "enabled" are surfaced.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

/// Feature flags advertised for a model; all default to false when absent.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

/// Upstream vendor serving a model.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

/// One part of a multipart chat message: plain text or an image reference.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Image reference for a vision request (URL or data URI).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
206
207impl Model {
208 pub fn uses_streaming(&self) -> bool {
209 self.capabilities.supports.streaming
210 }
211
212 pub fn id(&self) -> &str {
213 self.id.as_str()
214 }
215
216 pub fn display_name(&self) -> &str {
217 self.name.as_str()
218 }
219
220 pub fn max_token_count(&self) -> u64 {
221 self.capabilities.limits.max_prompt_tokens
222 }
223
224 pub fn supports_tools(&self) -> bool {
225 self.capabilities.supports.tool_calls
226 }
227
228 pub fn vendor(&self) -> ModelVendor {
229 self.vendor
230 }
231
232 pub fn supports_vision(&self) -> bool {
233 self.capabilities.supports.vision
234 }
235
236 pub fn supports_parallel_tool_calls(&self) -> bool {
237 self.capabilities.supports.parallel_tool_calls
238 }
239
240 pub fn tokenizer(&self) -> Option<&str> {
241 self.capabilities.tokenizer.as_deref()
242 }
243
244 pub fn supports_response(&self) -> bool {
245 self.supported_endpoints.len() > 0
246 && !self
247 .supported_endpoints
248 .contains(&ModelSupportedEndpoint::ChatCompletions)
249 && self
250 .supported_endpoints
251 .contains(&ModelSupportedEndpoint::Responses)
252 }
253}
254
/// Request payload for the Copilot chat-completions endpoint.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    // Number of completions to request.
    pub n: usize,
    // Whether to stream the response as server-sent events.
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// A function declaration offered to the model as a tool.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // JSON Schema describing the function's parameters.
    pub parameters: serde_json::Value,
}

/// A tool the model may call; currently only function tools exist.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// Constraint on whether/which tools the model should call.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

/// A chat message, tagged by role on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        // Opaque reasoning blob the API expects echoed back on later turns.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

/// Message content: a bare string or a list of text/image parts.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    /// An empty multipart content (serializes as an empty array).
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}

impl From<Vec<ChatMessagePart>> for ChatMessageContent {
    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
        // Collapse a single text part to the plain-string form.
        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
            ChatMessageContent::Plain(std::mem::take(text))
        } else {
            ChatMessageContent::Multipart(parts)
        }
    }
}

impl From<String> for ChatMessageContent {
    fn from(text: String) -> Self {
        ChatMessageContent::Plain(text)
    }
}

/// A tool call emitted by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name and JSON-encoded arguments of a called function.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments as a JSON string, exactly as produced by the model.
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
363
/// One chunk of a chat-completions response (a full body when not streaming,
/// or a single SSE event when streaming).
// NOTE(review): `tag = "type"` on a struct is unusual — serde's container tag
// is primarily for enums; confirm this attribute is intentional and that
// deserialization of untagged event payloads still succeeds.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    // Present on the final event / non-streaming body.
    pub usage: Option<Usage>,
}

/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

/// One completion choice; streaming events populate `delta`, non-streaming
/// bodies populate `message`.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Incremental (or full) message content for a choice.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    // Opaque reasoning blob to echo back on subsequent requests.
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

/// A fragment of a tool call; fragments with the same `index` are accumulated.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A fragment of a function call's name/arguments.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}

/// Wire format of the token-exchange response.
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp (seconds) at which the token expires.
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

/// Endpoints advertised alongside the API token; only `api` is used.
#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}

/// A short-lived Copilot API token plus the endpoint it is valid for.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}
428
429impl ApiToken {
430 pub fn remaining_seconds(&self) -> i64 {
431 self.expires_at
432 .timestamp()
433 .saturating_sub(chrono::Utc::now().timestamp())
434 }
435}
436
437impl TryFrom<ApiTokenResponse> for ApiToken {
438 type Error = anyhow::Error;
439
440 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
441 let expires_at =
442 DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
443
444 Ok(Self {
445 api_key: response.token,
446 expires_at,
447 api_endpoint: response.endpoints.api,
448 })
449 }
450}
451
/// Global wrapper giving app-wide access to the single `CopilotChat` entity.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

/// Holds Copilot Chat authentication state and the cached model list.
pub struct CopilotChat {
    // Long-lived GitHub OAuth token, read from the Copilot config files or the
    // `GH_COPILOT_TOKEN` environment variable.
    oauth_token: Option<String>,
    // Short-lived API token exchanged from the OAuth token; refreshed near expiry.
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    // Models advertised by the API; `None` until the first successful fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
463
464pub fn init(
465 fs: Arc<dyn Fs>,
466 client: Arc<dyn HttpClient>,
467 configuration: CopilotChatConfiguration,
468 cx: &mut App,
469) {
470 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
471 cx.set_global(GlobalCopilotChat(copilot_chat));
472}
473
474pub fn copilot_chat_config_dir() -> &'static PathBuf {
475 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
476
477 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
478 let config_dir = if cfg!(target_os = "windows") {
479 dirs::data_local_dir().expect("failed to determine LocalAppData directory")
480 } else {
481 std::env::var("XDG_CONFIG_HOME")
482 .map(PathBuf::from)
483 .unwrap_or_else(|_| home_dir().join(".config"))
484 };
485
486 config_dir.join("github-copilot")
487 })
488}
489
490fn copilot_chat_config_paths() -> [PathBuf; 2] {
491 let base_dir = copilot_chat_config_dir();
492 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
493}
494
495impl CopilotChat {
496 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
497 cx.try_global::<GlobalCopilotChat>()
498 .map(|model| model.0.clone())
499 }
500
501 fn new(
502 fs: Arc<dyn Fs>,
503 client: Arc<dyn HttpClient>,
504 configuration: CopilotChatConfiguration,
505 cx: &mut Context<Self>,
506 ) -> Self {
507 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
508 let dir_path = copilot_chat_config_dir();
509
510 cx.spawn(async move |this, cx| {
511 let mut parent_watch_rx = watch_config_dir(
512 cx.background_executor(),
513 fs.clone(),
514 dir_path.clone(),
515 config_paths,
516 );
517 while let Some(contents) = parent_watch_rx.next().await {
518 let oauth_domain =
519 this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
520 let oauth_token = extract_oauth_token(contents, &oauth_domain);
521
522 this.update(cx, |this, cx| {
523 this.oauth_token = oauth_token.clone();
524 cx.notify();
525 })?;
526
527 if oauth_token.is_some() {
528 Self::update_models(&this, cx).await?;
529 }
530 }
531 anyhow::Ok(())
532 })
533 .detach_and_log_err(cx);
534
535 let this = Self {
536 oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
537 api_token: None,
538 models: None,
539 configuration,
540 client,
541 };
542
543 if this.oauth_token.is_some() {
544 cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
545 .detach_and_log_err(cx);
546 }
547
548 this
549 }
550
551 async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
552 let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
553 (
554 this.oauth_token.clone(),
555 this.client.clone(),
556 this.configuration.clone(),
557 )
558 })?;
559
560 let oauth_token = oauth_token
561 .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
562
563 let token_url = configuration.token_url();
564 let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
565
566 let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
567 let models =
568 get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
569
570 this.update(cx, |this, cx| {
571 this.api_token = Some(api_token);
572 this.models = Some(models);
573 cx.notify();
574 })?;
575 anyhow::Ok(())
576 }
577
578 pub fn is_authenticated(&self) -> bool {
579 self.oauth_token.is_some()
580 }
581
582 pub fn models(&self) -> Option<&[Model]> {
583 self.models.as_deref()
584 }
585
586 pub async fn stream_completion(
587 request: Request,
588 is_user_initiated: bool,
589 mut cx: AsyncApp,
590 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
591 let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
592
593 let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
594 stream_completion(
595 client.clone(),
596 token.api_key,
597 api_url.into(),
598 request,
599 is_user_initiated,
600 )
601 .await
602 }
603
604 pub async fn stream_response(
605 request: responses::Request,
606 is_user_initiated: bool,
607 mut cx: AsyncApp,
608 ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
609 let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
610
611 let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
612 responses::stream_response(
613 client.clone(),
614 token.api_key,
615 api_url,
616 request,
617 is_user_initiated,
618 )
619 .await
620 }
621
622 async fn get_auth_details(
623 cx: &mut AsyncApp,
624 ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
625 let this = cx
626 .update(|cx| Self::global(cx))
627 .context("Copilot chat is not enabled")?;
628
629 let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
630 (
631 this.oauth_token.clone(),
632 this.api_token.clone(),
633 this.client.clone(),
634 this.configuration.clone(),
635 )
636 });
637
638 let oauth_token = oauth_token.context("No OAuth token available")?;
639
640 let token = match api_token {
641 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
642 _ => {
643 let token_url = configuration.token_url();
644 let token =
645 request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
646 this.update(cx, |this, cx| {
647 this.api_token = Some(token.clone());
648 cx.notify();
649 });
650 token
651 }
652 };
653
654 Ok((client, token, configuration))
655 }
656
657 pub fn set_configuration(
658 &mut self,
659 configuration: CopilotChatConfiguration,
660 cx: &mut Context<Self>,
661 ) {
662 let same_configuration = self.configuration == configuration;
663 self.configuration = configuration;
664 if !same_configuration {
665 self.api_token = None;
666 cx.spawn(async move |this, cx| {
667 Self::update_models(&this, cx).await?;
668 Ok::<_, anyhow::Error>(())
669 })
670 .detach();
671 }
672 }
673}
674
/// Fetches the model list and narrows it to chat models the user may pick,
/// moving the default chat model to the front.
async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        // Keep only picker-visible chat models whose policy (if any) is enabled.
        .filter(|model| {
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        // Collapse *consecutive* entries sharing a family, keeping the first.
        // NOTE(review): this assumes the API lists a family's variants
        // adjacently; non-adjacent duplicates would survive — confirm.
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    // Surface the default chat model first in the picker.
    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}
702
703async fn request_models(
704 models_url: Arc<str>,
705 api_token: String,
706 client: Arc<dyn HttpClient>,
707) -> Result<Vec<Model>> {
708 let request_builder = HttpRequest::builder()
709 .method(Method::GET)
710 .uri(models_url.as_ref())
711 .header("Authorization", format!("Bearer {}", api_token))
712 .header("Content-Type", "application/json")
713 .header("Copilot-Integration-Id", "vscode-chat")
714 .header("Editor-Version", "vscode/1.103.2")
715 .header("x-github-api-version", "2025-05-01");
716
717 let request = request_builder.body(AsyncBody::empty())?;
718
719 let mut response = client.send(request).await?;
720
721 anyhow::ensure!(
722 response.status().is_success(),
723 "Failed to request models: {}",
724 response.status()
725 );
726 let mut body = Vec::new();
727 response.body_mut().read_to_end(&mut body).await?;
728
729 let body_str = std::str::from_utf8(&body)?;
730
731 let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
732
733 Ok(models)
734}
735
736async fn request_api_token(
737 oauth_token: &str,
738 auth_url: Arc<str>,
739 client: Arc<dyn HttpClient>,
740) -> Result<ApiToken> {
741 let request_builder = HttpRequest::builder()
742 .method(Method::GET)
743 .uri(auth_url.as_ref())
744 .header("Authorization", format!("token {}", oauth_token))
745 .header("Accept", "application/json");
746
747 let request = request_builder.body(AsyncBody::empty())?;
748
749 let mut response = client.send(request).await?;
750
751 if response.status().is_success() {
752 let mut body = Vec::new();
753 response.body_mut().read_to_end(&mut body).await?;
754
755 let body_str = std::str::from_utf8(&body)?;
756
757 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
758 ApiToken::try_from(parsed)
759 } else {
760 let mut body = Vec::new();
761 response.body_mut().read_to_end(&mut body).await?;
762
763 let body_str = std::str::from_utf8(&body)?;
764 anyhow::bail!("Failed to request API token: {body_str}");
765 }
766}
767
768fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
769 serde_json::from_str::<serde_json::Value>(&contents)
770 .map(|v| {
771 v.as_object().and_then(|obj| {
772 obj.iter().find_map(|(key, value)| {
773 if key.starts_with(domain) {
774 value["oauth_token"].as_str().map(|v| v.to_string())
775 } else {
776 None
777 }
778 })
779 })
780 })
781 .ok()
782 .flatten()
783}
784
/// Sends a chat-completions request and returns a stream of response events:
/// parsed SSE events when `request.stream` is set, otherwise a single-item
/// stream with the full response body.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // Vision requests need an extra header; detect any image part anywhere in
    // the conversation.
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        // System messages carry plain strings and cannot contain images.
        _ => false,
    });

    // Copilot distinguishes user-initiated requests from agentic ones.
    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator)
        .when(is_vision_request, |builder| {
            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
        });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        // Parse the body as server-sent events: each data line carries one
        // JSON-encoded ResponseEvent; "[DONE]" terminates the stream.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Non-data lines (blank keep-alives, comments) are dropped.
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Events with no choices carry nothing useful.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-streaming: read the whole body and wrap it in a one-item stream
        // so callers can treat both modes uniformly.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
874
#[cfg(test)]
mod tests {
    use super::*;

    /// A malformed entry in the model list must be skipped, not fail the
    /// whole deserialization (see `deserialize_models_skip_errors`).
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
          "data": [
            {
              "billing": {
                "is_premium": false,
                "multiplier": 0
              },
              "capabilities": {
                "family": "gpt-4",
                "limits": {
                  "max_context_window_tokens": 32768,
                  "max_output_tokens": 4096,
                  "max_prompt_tokens": 32768
                },
                "object": "model_capabilities",
                "supports": { "streaming": true, "tool_calls": true },
                "tokenizer": "cl100k_base",
                "type": "chat"
              },
              "id": "gpt-4",
              "is_chat_default": false,
              "is_chat_fallback": false,
              "model_picker_enabled": false,
              "name": "GPT 4",
              "object": "model",
              "preview": false,
              "vendor": "Azure OpenAI",
              "version": "gpt-4-0613"
            },
            {
              "some-unknown-field": 123
            },
            {
              "billing": {
                "is_premium": true,
                "multiplier": 1,
                "restricted_to": [
                  "pro",
                  "pro_plus",
                  "business",
                  "enterprise"
                ]
              },
              "capabilities": {
                "family": "claude-3.7-sonnet",
                "limits": {
                  "max_context_window_tokens": 200000,
                  "max_output_tokens": 16384,
                  "max_prompt_tokens": 90000,
                  "vision": {
                    "max_prompt_image_size": 3145728,
                    "max_prompt_images": 1,
                    "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                  }
                },
                "object": "model_capabilities",
                "supports": {
                  "parallel_tool_calls": true,
                  "streaming": true,
                  "tool_calls": true,
                  "vision": true
                },
                "tokenizer": "o200k_base",
                "type": "chat"
              },
              "id": "claude-3.7-sonnet",
              "is_chat_default": false,
              "is_chat_fallback": false,
              "model_picker_enabled": true,
              "name": "Claude 3.7 Sonnet",
              "object": "model",
              "policy": {
                "state": "enabled",
                "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
              },
              "preview": false,
              "vendor": "Anthropic",
              "version": "claude-3.7-sonnet"
            }
          ],
          "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // The unknown-field entry is dropped; the two valid models survive.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    /// An unrecognized vendor string must map to `ModelVendor::Unknown`
    /// (via `#[serde(other)]`) instead of failing deserialization.
    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
          "data": [
            {
              "billing": {
                "is_premium": false,
                "multiplier": 1
              },
              "capabilities": {
                "family": "future-model",
                "limits": {
                  "max_context_window_tokens": 128000,
                  "max_output_tokens": 8192,
                  "max_prompt_tokens": 120000
                },
                "object": "model_capabilities",
                "supports": { "streaming": true, "tool_calls": true },
                "type": "chat"
              },
              "id": "future-model-v1",
              "is_chat_default": false,
              "is_chat_fallback": false,
              "model_picker_enabled": true,
              "name": "Future Model v1",
              "object": "model",
              "preview": false,
              "vendor": "SomeNewVendor",
              "version": "v1.0"
            }
          ],
          "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
}