// copilot_chat.rs

  1use std::path::PathBuf;
  2use std::sync::Arc;
  3use std::sync::OnceLock;
  4
  5use anyhow::Context as _;
  6use anyhow::{Result, anyhow};
  7use chrono::DateTime;
  8use collections::HashSet;
  9use fs::Fs;
 10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 11use gpui::WeakEntity;
 12use gpui::{App, AsyncApp, Global, prelude::*};
 13use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 14use itertools::Itertools;
 15use paths::home_dir;
 16use serde::{Deserialize, Serialize};
 17use settings::watch_config_dir;
 18
 19pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
 20
 21#[derive(Default, Clone, Debug, PartialEq)]
 22pub struct CopilotChatConfiguration {
 23    pub enterprise_uri: Option<String>,
 24}
 25
 26impl CopilotChatConfiguration {
 27    pub fn token_url(&self) -> String {
 28        if let Some(enterprise_uri) = &self.enterprise_uri {
 29            let domain = Self::parse_domain(enterprise_uri);
 30            format!("https://api.{}/copilot_internal/v2/token", domain)
 31        } else {
 32            "https://api.github.com/copilot_internal/v2/token".to_string()
 33        }
 34    }
 35
 36    pub fn oauth_domain(&self) -> String {
 37        if let Some(enterprise_uri) = &self.enterprise_uri {
 38            Self::parse_domain(enterprise_uri)
 39        } else {
 40            "github.com".to_string()
 41        }
 42    }
 43
 44    pub fn api_url_from_endpoint(&self, endpoint: &str) -> String {
 45        format!("{}/chat/completions", endpoint)
 46    }
 47
 48    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
 49        format!("{}/models", endpoint)
 50    }
 51
 52    fn parse_domain(enterprise_uri: &str) -> String {
 53        let uri = enterprise_uri.trim_end_matches('/');
 54
 55        if let Some(domain) = uri.strip_prefix("https://") {
 56            domain.split('/').next().unwrap_or(domain).to_string()
 57        } else if let Some(domain) = uri.strip_prefix("http://") {
 58            domain.split('/').next().unwrap_or(domain).to_string()
 59        } else {
 60            uri.split('/').next().unwrap_or(uri).to_string()
 61        }
 62    }
 63}
 64
/// Chat participant role, serialized lowercase ("user"/"assistant"/"system")
/// to match the Copilot Chat wire format.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
 72
/// Wrapper for the models endpoint response; only the `data` array is used.
#[derive(Deserialize)]
struct ModelSchema {
    // Custom deserializer drops malformed entries instead of failing the
    // whole response (see `deserialize_models_skip_errors`).
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
 78
 79fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
 80where
 81    D: serde::Deserializer<'de>,
 82{
 83    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
 84    let models = raw_values
 85        .into_iter()
 86        .filter_map(|value| match serde_json::from_value::<Model>(value) {
 87            Ok(model) => Some(model),
 88            Err(err) => {
 89                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
 90                None
 91            }
 92        })
 93        .collect();
 94
 95    Ok(models)
 96}
 97
/// A single Copilot Chat model as returned by the models endpoint.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    // May be absent; a missing policy is treated as enabled (see get_models).
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    model_picker_enabled: bool,
}
112
/// Billing metadata for a model.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // NOTE(review): presumably the premium-request cost multiplier — confirm.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}
122
/// Capability metadata for a model: family, token limits, supported features.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    // Missing limits deserialize as zeroed ModelLimits.
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // Only models with type "chat" are surfaced (see get_models).
    #[serde(rename = "type")]
    model_type: String,
}
132
/// Token limits for a model; absent fields default to 0.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    // NOTE(review): u64 while the other limits are usize — consider unifying.
    // This is the value exposed via Model::max_token_count.
    #[serde(default)]
    max_prompt_tokens: u64,
}
142
/// Per-model policy; only models whose state is "enabled" are listed
/// (see get_models).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}
147
/// Feature flags for a model; absent flags deserialize as false.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
159
/// Upstream vendor serving a model.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
}
168
/// One part of a multipart chat message, tagged by `type` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
177
/// Image reference for a vision message part.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    // NOTE(review): presumably an https URL or data URI — confirm with callers.
    pub url: String,
}
182
183impl Model {
184    pub fn uses_streaming(&self) -> bool {
185        self.capabilities.supports.streaming
186    }
187
188    pub fn id(&self) -> &str {
189        self.id.as_str()
190    }
191
192    pub fn display_name(&self) -> &str {
193        self.name.as_str()
194    }
195
196    pub fn max_token_count(&self) -> u64 {
197        self.capabilities.limits.max_prompt_tokens
198    }
199
200    pub fn supports_tools(&self) -> bool {
201        self.capabilities.supports.tool_calls
202    }
203
204    pub fn vendor(&self) -> ModelVendor {
205        self.vendor
206    }
207
208    pub fn supports_vision(&self) -> bool {
209        self.capabilities.supports.vision
210    }
211
212    pub fn supports_parallel_tool_calls(&self) -> bool {
213        self.capabilities.supports.parallel_tool_calls
214    }
215}
216
/// Body of a chat completion request.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    // Number of completions to generate.
    pub n: usize,
    // When true, the response is delivered as server-sent events.
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    // Omitted from JSON when empty / None.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
230
/// Description of a callable function exposed to the model.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // JSON schema describing the function's parameters.
    pub parameters: serde_json::Value,
}
237
/// A tool the model may call; currently only function tools exist.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}
243
/// How the model should choose tools; serialized lowercase
/// ("auto"/"any"/"none").
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
251
/// A message in a chat conversation, tagged by `role` on the wire
/// ("assistant"/"user"/"system"/"tool").
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        // Omitted from JSON when the assistant made no tool calls.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        // ID of the assistant tool call this message responds to.
        tool_call_id: String,
    },
}
271
/// Message content: either a bare string or a list of typed parts.
/// `untagged` makes serde pick whichever variant matches the JSON shape.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}
278
279impl ChatMessageContent {
280    pub fn empty() -> Self {
281        ChatMessageContent::Multipart(vec![])
282    }
283}
284
285impl From<Vec<ChatMessagePart>> for ChatMessageContent {
286    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
287        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
288            ChatMessageContent::Plain(std::mem::take(text))
289        } else {
290            ChatMessageContent::Multipart(parts)
291        }
292    }
293}
294
295impl From<String> for ChatMessageContent {
296    fn from(text: String) -> Self {
297        ChatMessageContent::Plain(text)
298    }
299}
300
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so `type`/`function` appear at the top level of the object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
307
/// Payload of a tool call, tagged by `type`; only function calls are supported.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
313
/// A function invocation: name plus its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments as a raw JSON-encoded string, not a parsed value.
    pub arguments: String,
}
319
/// One response payload: a streamed SSE chunk, or the entire non-streaming
/// response (see stream_completion).
// NOTE(review): serde's `tag` attribute is normally used on enums; its
// presence on this struct looks unintended — confirm.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    // NOTE(review): presumably only present on the final chunk when
    // streaming — confirm against the API.
    pub usage: Option<Usage>,
}
327
/// Token accounting reported by the API for a completion.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}
334
/// One completion choice within a response event.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: usize,
    pub finish_reason: Option<String>,
    // NOTE(review): presumably `delta` is set on streaming chunks and
    // `message` on non-streaming responses (OpenAI-style) — confirm.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}
342
/// Incremental (or complete) message payload within a choice.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    // Empty when the chunk carries no tool-call fragments.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}
350
/// Streaming fragment of a tool call; fragments sharing an `index` belong to
/// the same call — presumably accumulated by the consumer (confirm).
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
357
/// Streaming fragment of a function call's name and/or arguments.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
363
/// Raw response from the `copilot_internal/v2/token` endpoint.
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp in seconds (converted via DateTime::from_timestamp).
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}
370
/// Endpoint URLs returned with the API token; only the API base is used.
#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}
375
/// Short-lived API token plus the endpoint it is valid for.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}
382
383impl ApiToken {
384    pub fn remaining_seconds(&self) -> i64 {
385        self.expires_at
386            .timestamp()
387            .saturating_sub(chrono::Utc::now().timestamp())
388    }
389}
390
391impl TryFrom<ApiTokenResponse> for ApiToken {
392    type Error = anyhow::Error;
393
394    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
395        let expires_at =
396            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
397
398        Ok(Self {
399            api_key: response.token,
400            expires_at,
401            api_endpoint: response.endpoints.api,
402        })
403    }
404}
405
/// Newtype registering the CopilotChat entity as the gpui global.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
409
/// Global Copilot Chat state: credentials, model list, and HTTP client.
pub struct CopilotChat {
    // Long-lived OAuth token, from the env var or the Copilot config files.
    oauth_token: Option<String>,
    // Short-lived token exchanged for the OAuth token; refreshed near expiry.
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    // None until the first successful update_models call.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
417
/// Creates the CopilotChat entity and registers it as a gpui global.
pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}
427
428pub fn copilot_chat_config_dir() -> &'static PathBuf {
429    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
430
431    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
432        let config_dir = if cfg!(target_os = "windows") {
433            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
434        } else {
435            std::env::var("XDG_CONFIG_HOME")
436                .map(PathBuf::from)
437                .unwrap_or_else(|_| home_dir().join(".config"))
438        };
439
440        config_dir.join("github-copilot")
441    })
442}
443
444fn copilot_chat_config_paths() -> [PathBuf; 2] {
445    let base_dir = copilot_chat_config_dir();
446    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
447}
448
impl CopilotChat {
    /// Returns the global CopilotChat entity, if `init` has been called.
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    /// Builds the entity, seeding the OAuth token from the environment and
    /// watching the Copilot config directory for token changes.
    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        // Re-extract the OAuth token whenever the watched config files change.
        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                // Only fetch models once a token is actually present.
                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            // Env var takes effect immediately; the watcher above may later
            // overwrite it with the on-disk token.
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_token: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

    /// Exchanges the OAuth token for an API token, fetches the model list,
    /// and stores both on the entity.
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let token_url = configuration.token_url();
        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;

        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
        let models =
            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;

        this.update(cx, |this, cx| {
            this.api_token = Some(api_token);
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    /// True when an OAuth token is present (env var or config file).
    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    /// The fetched model list, if update_models has completed.
    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    /// Starts a chat completion request, first refreshing the short-lived API
    /// token when it is missing or within five minutes of expiry.
    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let this = cx
            .update(|cx| Self::global(cx))
            .ok()
            .flatten()
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client, configuration) = this.read_with(&cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

        // Reuse the cached API token unless it expires within 5 minutes.
        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(),
            _ => {
                let token_url = configuration.token_url();
                let token =
                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
                this.update(&mut cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        let api_url = configuration.api_url_from_endpoint(&token.api_endpoint);
        stream_completion(
            client.clone(),
            token.api_key,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    /// Swaps in a new configuration; when it differs, drops the cached API
    /// token and re-fetches models (errors from that refresh are discarded).
    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_token = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}
604
/// Fetches the model list and filters it to picker-enabled chat models,
/// de-duplicated by family, with the default chat model moved to the front.
async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            // Keep chat-type models that are picker-enabled and whose policy
            // (when present) is "enabled".
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        // NOTE(review): Itertools::dedup_by only collapses *consecutive*
        // duplicates, so this relies on the API grouping families together.
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    // Surface the default chat model first.
    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}
632
633async fn request_models(
634    models_url: Arc<str>,
635    api_token: String,
636    client: Arc<dyn HttpClient>,
637) -> Result<Vec<Model>> {
638    let request_builder = HttpRequest::builder()
639        .method(Method::GET)
640        .uri(models_url.as_ref())
641        .header("Authorization", format!("Bearer {}", api_token))
642        .header("Content-Type", "application/json")
643        .header("Copilot-Integration-Id", "vscode-chat")
644        .header("Editor-Version", "vscode/1.103.2")
645        .header("x-github-api-version", "2025-05-01");
646
647    let request = request_builder.body(AsyncBody::empty())?;
648
649    let mut response = client.send(request).await?;
650
651    anyhow::ensure!(
652        response.status().is_success(),
653        "Failed to request models: {}",
654        response.status()
655    );
656    let mut body = Vec::new();
657    response.body_mut().read_to_end(&mut body).await?;
658
659    let body_str = std::str::from_utf8(&body)?;
660
661    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
662
663    Ok(models)
664}
665
666async fn request_api_token(
667    oauth_token: &str,
668    auth_url: Arc<str>,
669    client: Arc<dyn HttpClient>,
670) -> Result<ApiToken> {
671    let request_builder = HttpRequest::builder()
672        .method(Method::GET)
673        .uri(auth_url.as_ref())
674        .header("Authorization", format!("token {}", oauth_token))
675        .header("Accept", "application/json");
676
677    let request = request_builder.body(AsyncBody::empty())?;
678
679    let mut response = client.send(request).await?;
680
681    if response.status().is_success() {
682        let mut body = Vec::new();
683        response.body_mut().read_to_end(&mut body).await?;
684
685        let body_str = std::str::from_utf8(&body)?;
686
687        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
688        ApiToken::try_from(parsed)
689    } else {
690        let mut body = Vec::new();
691        response.body_mut().read_to_end(&mut body).await?;
692
693        let body_str = std::str::from_utf8(&body)?;
694        anyhow::bail!("Failed to request API token: {body_str}");
695    }
696}
697
698fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
699    serde_json::from_str::<serde_json::Value>(&contents)
700        .map(|v| {
701            v.as_object().and_then(|obj| {
702                obj.iter().find_map(|(key, value)| {
703                    if key.starts_with(domain) {
704                        value["oauth_token"].as_str().map(|v| v.to_string())
705                    } else {
706                        None
707                    }
708                })
709            })
710        })
711        .ok()
712        .flatten()
713}
714
/// Sends a chat completion request. Returns a stream of response events: an
/// SSE-parsed stream when `request.stream` is true, otherwise a single-item
/// stream carrying the whole response.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // The vision header must be set when any message carries an image part.
    let is_vision_request = request.messages.iter().any(|message| match message {
      ChatMessage::User { content }
      | ChatMessage::Assistant { content, .. }
      | ChatMessage::Tool { content, .. } => {
          matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
      }
      _ => false,
  });

    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let mut request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator);

    if is_vision_request {
        request_builder =
            request_builder.header("Copilot-Vision-Request", is_vision_request.to_string());
    }

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        // Parse server-sent events line by line: each "data: " payload is a
        // JSON ResponseEvent; "[DONE]" ends the stream, and events with no
        // choices are dropped.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-streaming: read the full body and wrap the single response in a
        // one-item stream so both paths share a return type.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
806
807#[cfg(test)]
808mod tests {
809    use super::*;
810
    /// Verifies that a malformed entry in the `data` array is skipped rather
    /// than failing deserialization of the entire model list.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
              "data": [
                {
                  "billing": {
                    "is_premium": false,
                    "multiplier": 0
                  },
                  "capabilities": {
                    "family": "gpt-4",
                    "limits": {
                      "max_context_window_tokens": 32768,
                      "max_output_tokens": 4096,
                      "max_prompt_tokens": 32768
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "tokenizer": "cl100k_base",
                    "type": "chat"
                  },
                  "id": "gpt-4",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": false,
                  "name": "GPT 4",
                  "object": "model",
                  "preview": false,
                  "vendor": "Azure OpenAI",
                  "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                  "billing": {
                    "is_premium": true,
                    "multiplier": 1,
                    "restricted_to": [
                      "pro",
                      "pro_plus",
                      "business",
                      "enterprise"
                    ]
                  },
                  "capabilities": {
                    "family": "claude-3.7-sonnet",
                    "limits": {
                      "max_context_window_tokens": 200000,
                      "max_output_tokens": 16384,
                      "max_prompt_tokens": 90000,
                      "vision": {
                        "max_prompt_image_size": 3145728,
                        "max_prompt_images": 1,
                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                      }
                    },
                    "object": "model_capabilities",
                    "supports": {
                      "parallel_tool_calls": true,
                      "streaming": true,
                      "tool_calls": true,
                      "vision": true
                    },
                    "tokenizer": "o200k_base",
                    "type": "chat"
                  },
                  "id": "claude-3.7-sonnet",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": true,
                  "name": "Claude 3.7 Sonnet",
                  "object": "model",
                  "policy": {
                    "state": "enabled",
                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                  },
                  "preview": false,
                  "vendor": "Anthropic",
                  "version": "claude-3.7-sonnet"
                }
              ],
              "object": "list"
            }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // The middle unknown-field entry is dropped; the two valid models remain.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
902}