1pub mod responses;
2
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::sync::OnceLock;
6
7use anyhow::Context as _;
8use anyhow::{Result, anyhow};
9use collections::HashSet;
10use fs::Fs;
11use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
12use gpui::WeakEntity;
13use gpui::{App, AsyncApp, Global, prelude::*};
14use http_client::HttpRequestExt;
15use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
16use paths::home_dir;
17use serde::{Deserialize, Serialize};
18
19use settings::watch_config_dir;
20
/// Environment variable that can supply the GitHub Copilot OAuth token
/// directly, bypassing the Copilot config files.
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
/// Fallback Copilot API endpoint used when GraphQL discovery fails.
const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
23
/// Configuration for connecting to GitHub Copilot Chat, optionally against a
/// GitHub Enterprise deployment.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    /// Base URI of a GitHub Enterprise instance (with or without scheme).
    /// `None` means public github.com.
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    /// Domain used for the OAuth flow: the enterprise domain when configured,
    /// otherwise "github.com".
    pub fn oauth_domain(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => Self::parse_domain(enterprise_uri),
            None => "github.com".to_string(),
        }
    }

    /// GraphQL endpoint used to discover the Copilot API endpoint.
    pub fn graphql_url(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => {
                let domain = Self::parse_domain(enterprise_uri);
                format!("https://{}/api/graphql", domain)
            }
            None => "https://api.github.com/graphql".to_string(),
        }
    }

    /// `/chat/completions` URL under the given API endpoint.
    pub fn chat_completions_url(&self, api_endpoint: &str) -> String {
        format!("{}/chat/completions", api_endpoint)
    }

    /// `/responses` URL under the given API endpoint.
    pub fn responses_url(&self, api_endpoint: &str) -> String {
        format!("{}/responses", api_endpoint)
    }

    /// `/models` URL under the given API endpoint.
    pub fn models_url(&self, api_endpoint: &str) -> String {
        format!("{}/models", api_endpoint)
    }

    /// Extracts the bare domain from an enterprise URI, tolerating an optional
    /// `https://`/`http://` scheme, a trailing slash, and any path suffix.
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');
        // Strip an optional scheme once instead of duplicating the logic per
        // scheme, then keep everything before the first '/'.
        let without_scheme = uri
            .strip_prefix("https://")
            .or_else(|| uri.strip_prefix("http://"))
            .unwrap_or(uri);
        without_scheme
            .split('/')
            .next()
            .unwrap_or(without_scheme)
            .to_string()
    }
}
71
/// Author role of a chat message; serialized lowercase on the wire
/// ("user", "assistant", "system").
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
79
/// API endpoints a model can be served through, as advertised by the
/// `supported_endpoints` field of the `/models` response.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
    #[serde(rename = "/v1/messages")]
    Messages,
    /// Unknown endpoint that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}
92
/// Wire shape of the `/models` response; only the `data` array is consumed.
#[derive(Deserialize)]
struct ModelSchema {
    // Malformed entries are logged and skipped instead of failing the whole
    // response (see `deserialize_models_skip_errors`).
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
98
99fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
100where
101 D: serde::Deserializer<'de>,
102{
103 let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
104 let models = raw_values
105 .into_iter()
106 .filter_map(|value| match serde_json::from_value::<Model>(value) {
107 Ok(model) => Some(model),
108 Err(err) => {
109 log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
110 None
111 }
112 })
113 .collect();
114
115 Ok(models)
116}
117
/// One model entry from the Copilot `/models` listing.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    // Stable identifier sent in completion requests.
    id: String,
    // Human-readable name for display.
    name: String,
    // Present only for models gated behind a policy the user must enable.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    // Missing field deserializes to an empty list.
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}
134
/// Billing metadata for a model.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // NOTE(review): presumably the premium-request multiplier — semantics are
    // defined by the Copilot API, not visible here.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

/// Capability block for a model: family, token limits, and feature flags.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // e.g. "chat"; non-chat models are filtered out by `get_models`.
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}

/// Token limits for a model; all default to 0 when the server omits them.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    // NOTE(review): typed u64 while the other two limits are usize —
    // presumably historical; confirm before unifying.
    #[serde(default)]
    max_prompt_tokens: u64,
}

/// Policy gate for a model; only models whose state is "enabled" are surfaced
/// (filtering happens in `get_models`).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

/// Per-model feature flags; absent flags default to false.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
183
/// Upstream provider of a model, as reported by the `/models` listing.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}
197
/// One part of a multipart chat message: either text or an image reference,
/// tagged by the `type` field on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Image payload of an `image_url` message part.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
211
212impl Model {
213 pub fn uses_streaming(&self) -> bool {
214 self.capabilities.supports.streaming
215 }
216
217 pub fn id(&self) -> &str {
218 self.id.as_str()
219 }
220
221 pub fn display_name(&self) -> &str {
222 self.name.as_str()
223 }
224
225 pub fn max_token_count(&self) -> u64 {
226 self.capabilities.limits.max_context_window_tokens as u64
227 }
228
229 pub fn supports_tools(&self) -> bool {
230 self.capabilities.supports.tool_calls
231 }
232
233 pub fn vendor(&self) -> ModelVendor {
234 self.vendor
235 }
236
237 pub fn supports_vision(&self) -> bool {
238 self.capabilities.supports.vision
239 }
240
241 pub fn supports_parallel_tool_calls(&self) -> bool {
242 self.capabilities.supports.parallel_tool_calls
243 }
244
245 pub fn tokenizer(&self) -> Option<&str> {
246 self.capabilities.tokenizer.as_deref()
247 }
248
249 pub fn supports_response(&self) -> bool {
250 self.supported_endpoints.len() > 0
251 && !self
252 .supported_endpoints
253 .contains(&ModelSupportedEndpoint::ChatCompletions)
254 && self
255 .supported_endpoints
256 .contains(&ModelSupportedEndpoint::Responses)
257 }
258}
259
/// Body of a `/chat/completions` request.
#[derive(Serialize, Deserialize)]
pub struct Request {
    // NOTE(review): Copilot-specific flag; exact semantics defined by the API,
    // not visible in this file.
    pub intent: bool,
    // Number of completions to generate.
    pub n: usize,
    // When true the response arrives as server-sent events.
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// A function definition exposed to the model as a tool.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // JSON Schema describing the function's parameters.
    pub parameters: serde_json::Value,
}

/// A tool the model may invoke; currently only function tools exist.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// How the model should choose a tool; serialized lowercase.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
294
/// A chat message, tagged on the wire by its `role` field.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        // NOTE(review): opaque reasoning blob round-tripped to the API —
        // format is provider-defined and not visible here.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

/// Message content: either a plain string or a list of typed parts.
/// Untagged, so a bare JSON string deserializes as `Plain`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}
325
326impl ChatMessageContent {
327 pub fn empty() -> Self {
328 ChatMessageContent::Multipart(vec![])
329 }
330}
331
332impl From<Vec<ChatMessagePart>> for ChatMessageContent {
333 fn from(mut parts: Vec<ChatMessagePart>) -> Self {
334 if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
335 ChatMessageContent::Plain(std::mem::take(text))
336 } else {
337 ChatMessageContent::Multipart(parts)
338 }
339 }
340}
341
342impl From<String> for ChatMessageContent {
343 fn from(text: String) -> Self {
344 ChatMessageContent::Plain(text)
345 }
346}
347
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` on the wire.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A concrete function invocation.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments as a raw JSON-encoded string, exactly as sent on the wire.
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
368
/// One chunk of a chat-completions response (streamed or whole).
// NOTE(review): `tag`/`rename_all` on a struct only affect an injected tag
// field and field-name casing; fields here are already snake_case — confirm
// this attribute is intentional.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    // Token accounting, when the API includes it on a chunk.
    pub usage: Option<Usage>,
}

/// Token usage accounting for a completion.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a response chunk. Streamed chunks populate `delta`;
/// non-streamed responses populate `message`.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Incremental (or complete) assistant message payload.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    // Mirrors the reasoning fields on `ChatMessage::Assistant`.
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

/// Partial tool call carried in a streamed delta.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    // NOTE(review): presumably used to stitch fragments of the same call
    // together across chunks — the consumer is not in this file.
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function-call data within a `ToolCallChunk`.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}
414
/// Newtype wrapper storing the `CopilotChat` entity as a GPUI global.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
418
/// Application-wide Copilot Chat state: credentials, the discovered API
/// endpoint, and the cached model list.
pub struct CopilotChat {
    // OAuth token from the env var or the Copilot config files; None = signed out.
    oauth_token: Option<String>,
    // API endpoint discovered via GraphQL; None until resolved.
    api_endpoint: Option<String>,
    configuration: CopilotChatConfiguration,
    // Models fetched from the API; None until the first successful fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
426
427pub fn init(
428 fs: Arc<dyn Fs>,
429 client: Arc<dyn HttpClient>,
430 configuration: CopilotChatConfiguration,
431 cx: &mut App,
432) {
433 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
434 cx.set_global(GlobalCopilotChat(copilot_chat));
435}
436
437pub fn copilot_chat_config_dir() -> &'static PathBuf {
438 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
439
440 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
441 let config_dir = if cfg!(target_os = "windows") {
442 dirs::data_local_dir().expect("failed to determine LocalAppData directory")
443 } else {
444 std::env::var("XDG_CONFIG_HOME")
445 .map(PathBuf::from)
446 .unwrap_or_else(|_| home_dir().join(".config"))
447 };
448
449 config_dir.join("github-copilot")
450 })
451}
452
453fn copilot_chat_config_paths() -> [PathBuf; 2] {
454 let base_dir = copilot_chat_config_dir();
455 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
456}
457
impl CopilotChat {
    /// Returns the app-global `CopilotChat` entity, if `init` has run.
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    /// Builds the chat state and spawns a background task watching the Copilot
    /// config directory for credential changes.
    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        // Watch hosts.json/apps.json; on every change re-extract the OAuth
        // token and, when a token is present, refresh the model list.
        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        // An env-var token takes effect immediately, before any file is seen.
        let this = Self {
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_endpoint: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

    /// Reads credentials/config off the entity, resolves the API endpoint,
    /// fetches the model list, and stores it back on the entity.
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let api_endpoint =
            Self::resolve_api_endpoint(&this, &oauth_token, &configuration, &client, cx).await?;

        let models_url = configuration.models_url(&api_endpoint);
        let models = get_models(models_url.into(), oauth_token, client.clone()).await?;

        this.update(cx, |this, cx| {
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    /// True when an OAuth token is available (env var or config file).
    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    /// The cached model list, or `None` if it hasn't been fetched yet.
    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    /// Issues a chat-completions request through the global instance.
    /// `is_user_initiated` is forwarded as the `X-Initiator` header.
    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, oauth_token, api_endpoint, configuration) =
            Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url(&api_endpoint);
        stream_completion(
            client.clone(),
            oauth_token,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    /// Issues a `/responses` request through the global instance.
    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, oauth_token, api_endpoint, configuration) =
            Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url(&api_endpoint);
        responses::stream_response(
            client.clone(),
            oauth_token,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

    /// Collects everything needed to issue an API request: client, token,
    /// endpoint (resolved and cached on first use), and configuration.
    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(
        Arc<dyn HttpClient>,
        String,
        String,
        CopilotChatConfiguration,
    )> {
        let this = cx
            .update(|cx| Self::global(cx))
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_endpoint, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_endpoint.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        });

        let oauth_token = oauth_token.context("No OAuth token available")?;

        // Resolve lazily: only hit GraphQL when no endpoint is cached yet.
        let api_endpoint = match api_endpoint {
            Some(endpoint) => endpoint,
            None => {
                let weak = this.downgrade();
                Self::resolve_api_endpoint(&weak, &oauth_token, &configuration, &client, cx).await?
            }
        };

        Ok((client, oauth_token, api_endpoint, configuration))
    }

    /// Discovers the Copilot API endpoint via GraphQL, falling back to the
    /// public default on failure, and caches the result on the entity.
    async fn resolve_api_endpoint(
        this: &WeakEntity<Self>,
        oauth_token: &str,
        configuration: &CopilotChatConfiguration,
        client: &Arc<dyn HttpClient>,
        cx: &mut AsyncApp,
    ) -> Result<String> {
        let api_endpoint = match discover_api_endpoint(oauth_token, configuration, client).await {
            Ok(endpoint) => endpoint,
            Err(error) => {
                // Discovery failing shouldn't block chat; use the public endpoint.
                log::warn!(
                    "Failed to discover Copilot API endpoint via GraphQL, \
                    falling back to {DEFAULT_COPILOT_API_ENDPOINT}: {error:#}"
                );
                DEFAULT_COPILOT_API_ENDPOINT.to_string()
            }
        };

        this.update(cx, |this, cx| {
            this.api_endpoint = Some(api_endpoint.clone());
            cx.notify();
        })?;

        Ok(api_endpoint)
    }

    /// Replaces the configuration; when it actually changed, invalidates the
    /// cached endpoint and refreshes the model list in the background.
    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            // The endpoint can differ per deployment; force re-discovery.
            self.api_endpoint = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}
662
663async fn get_models(
664 models_url: Arc<str>,
665 oauth_token: String,
666 client: Arc<dyn HttpClient>,
667) -> Result<Vec<Model>> {
668 let all_models = request_models(models_url, oauth_token, client).await?;
669
670 let mut models: Vec<Model> = all_models
671 .into_iter()
672 .filter(|model| {
673 model.model_picker_enabled
674 && model.capabilities.model_type.as_str() == "chat"
675 && model
676 .policy
677 .as_ref()
678 .is_none_or(|policy| policy.state == "enabled")
679 })
680 .collect();
681
682 if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
683 let default_model = models.remove(default_model_position);
684 models.insert(0, default_model);
685 }
686
687 Ok(models)
688}
689
/// Minimal GraphQL response shape for Copilot endpoint discovery
/// (`viewer { copilotEndpoints { api } }`).
#[derive(Deserialize)]
struct GraphQLResponse {
    data: Option<GraphQLData>,
}

#[derive(Deserialize)]
struct GraphQLData {
    viewer: GraphQLViewer,
}

#[derive(Deserialize)]
struct GraphQLViewer {
    #[serde(rename = "copilotEndpoints")]
    copilot_endpoints: GraphQLCopilotEndpoints,
}

#[derive(Deserialize)]
struct GraphQLCopilotEndpoints {
    // Base URL of the Copilot API for this account.
    api: String,
}
710
711pub(crate) async fn discover_api_endpoint(
712 oauth_token: &str,
713 configuration: &CopilotChatConfiguration,
714 client: &Arc<dyn HttpClient>,
715) -> Result<String> {
716 let graphql_url = configuration.graphql_url();
717 let query = serde_json::json!({
718 "query": "query { viewer { copilotEndpoints { api } } }"
719 });
720
721 let request = HttpRequest::builder()
722 .method(Method::POST)
723 .uri(graphql_url.as_str())
724 .header("Authorization", format!("Bearer {}", oauth_token))
725 .header("Content-Type", "application/json")
726 .body(AsyncBody::from(serde_json::to_string(&query)?))?;
727
728 let mut response = client.send(request).await?;
729
730 anyhow::ensure!(
731 response.status().is_success(),
732 "GraphQL endpoint discovery failed: {}",
733 response.status()
734 );
735
736 let mut body = Vec::new();
737 response.body_mut().read_to_end(&mut body).await?;
738 let body_str = std::str::from_utf8(&body)?;
739
740 let parsed: GraphQLResponse = serde_json::from_str(body_str)
741 .context("Failed to parse GraphQL response for Copilot endpoint discovery")?;
742
743 let data = parsed
744 .data
745 .context("GraphQL response contained no data field")?;
746
747 Ok(data.viewer.copilot_endpoints.api)
748}
749
/// Applies the headers common to all Copilot API requests: bearer auth, JSON
/// content type, an editor identifier, and (when known) whether the request
/// was user- or agent-initiated.
pub(crate) fn copilot_request_headers(
    builder: http_client::Builder,
    oauth_token: &str,
    is_user_initiated: Option<bool>,
) -> http_client::Builder {
    builder
        .header("Authorization", format!("Bearer {}", oauth_token))
        .header("Content-Type", "application/json")
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                // option_env! keeps this building even outside cargo.
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        // Only set X-Initiator when the caller knows the initiation kind.
        .when_some(is_user_initiated, |builder, is_user_initiated| {
            builder.header(
                "X-Initiator",
                if is_user_initiated { "user" } else { "agent" },
            )
        })
}
772
773async fn request_models(
774 models_url: Arc<str>,
775 oauth_token: String,
776 client: Arc<dyn HttpClient>,
777) -> Result<Vec<Model>> {
778 let request_builder = copilot_request_headers(
779 HttpRequest::builder()
780 .method(Method::GET)
781 .uri(models_url.as_ref()),
782 &oauth_token,
783 None,
784 )
785 .header("x-github-api-version", "2025-05-01");
786
787 let request = request_builder.body(AsyncBody::empty())?;
788
789 let mut response = client.send(request).await?;
790
791 anyhow::ensure!(
792 response.status().is_success(),
793 "Failed to request models: {}",
794 response.status()
795 );
796 let mut body = Vec::new();
797 response.body_mut().read_to_end(&mut body).await?;
798
799 let body_str = std::str::from_utf8(&body)?;
800
801 let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
802
803 Ok(models)
804}
805
806fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
807 serde_json::from_str::<serde_json::Value>(&contents)
808 .map(|v| {
809 v.as_object().and_then(|obj| {
810 obj.iter().find_map(|(key, value)| {
811 if key.starts_with(domain) {
812 value["oauth_token"].as_str().map(|v| v.to_string())
813 } else {
814 None
815 }
816 })
817 })
818 })
819 .ok()
820 .flatten()
821}
822
/// Sends a chat-completions request and returns the response as a stream of
/// `ResponseEvent`s (a single-item stream when `request.stream` is false).
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    oauth_token: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // Vision requests need an extra header; detect any image part in any message.
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        // System messages carry plain strings and can't contain images.
        _ => false,
    });

    let request_builder = copilot_request_headers(
        HttpRequest::builder()
            .method(Method::POST)
            .uri(completion_url.as_ref()),
        &oauth_token,
        Some(is_user_initiated),
    )
    .when(is_vision_request, |builder| {
        builder.header("Copilot-Vision-Request", is_vision_request.to_string())
    });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        // Read the error body so the failure message is actionable.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        // Server-sent events: one JSON payload per "data: " line, terminated
        // by a "[DONE]" sentinel; non-data lines are dropped.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Chunks without choices carry nothing useful; drop them.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-streaming: parse the single JSON body and wrap it in a one-item stream.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
903
904#[cfg(test)]
905mod tests {
906 use super::*;
907
    // The lenient deserializer must keep valid entries and drop ones that
    // don't match the Model schema (here, an object with only an unknown field).
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 0
                    },
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "billing": {
                        "is_premium": true,
                        "multiplier": 1,
                        "restricted_to": [
                            "pro",
                            "pro_plus",
                            "business",
                            "enterprise"
                        ]
                    },
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // The malformed middle entry is skipped; the two valid models survive.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
999
    // A vendor string we've never seen must map to ModelVendor::Unknown
    // (via #[serde(other)]) instead of failing deserialization.
    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 1
                    },
                    "capabilities": {
                        "family": "future-model",
                        "limits": {
                            "max_context_window_tokens": 128000,
                            "max_output_tokens": 8192,
                            "max_prompt_tokens": 120000
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "future-model-v1",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Future Model v1",
                    "object": "model",
                    "preview": false,
                    "vendor": "SomeNewVendor",
                    "version": "v1.0"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
1040
    // Guards against regressing max_token_count() to the (smaller)
    // max_prompt_tokens limit.
    #[test]
    fn test_max_token_count_returns_context_window_not_prompt_tokens() {
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4"
                },
                {
                    "billing": { "is_premium": false, "multiplier": 1 },
                    "capabilities": {
                        "family": "gpt-4o",
                        "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "gpt-4o",
                    "is_chat_default": true,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "GPT-4o",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4o"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // max_token_count() should return context window (200000), not prompt tokens (90000)
        assert_eq!(schema.data[0].max_token_count(), 200000);

        // GPT-4o should return 128000 (context window), not 110000 (prompt tokens)
        assert_eq!(schema.data[1].max_token_count(), 128000);
    }
1095
    #[test]
    fn test_models_with_pending_policy_deserialize() {
        // This test verifies that models with policy states other than "enabled"
        // (such as "pending" or "requires_consent") are properly deserialized.
        // Note: These models will be filtered out by get_models() and won't appear
        // in the model picker until the user enables them on GitHub.
        let json = r#"{
            "data": [
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-sonnet-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-sonnet-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Sonnet 4",
                    "object": "model",
                    "policy": {
                        "state": "pending",
                        "terms": "Enable access to Claude models from Anthropic."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-sonnet-4"
                },
                {
                    "billing": { "is_premium": true, "multiplier": 1 },
                    "capabilities": {
                        "family": "claude-opus-4",
                        "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "claude-opus-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude Opus 4",
                    "object": "model",
                    "policy": {
                        "state": "requires_consent",
                        "terms": "Enable access to Claude models from Anthropic."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-opus-4"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // Both models should deserialize successfully (filtering happens in get_models)
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "claude-sonnet-4");
        assert_eq!(schema.data[1].id, "claude-opus-4");
    }
1161
1162 #[test]
1163 fn test_multiple_anthropic_models_preserved() {
1164 // This test verifies that multiple Claude models from Anthropic
1165 // are all preserved and not incorrectly deduplicated.
1166 // This was the root cause of issue #47540.
1167 let json = r#"{
1168 "data": [
1169 {
1170 "billing": { "is_premium": true, "multiplier": 1 },
1171 "capabilities": {
1172 "family": "claude-sonnet-4",
1173 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1174 "object": "model_capabilities",
1175 "supports": { "streaming": true, "tool_calls": true },
1176 "type": "chat"
1177 },
1178 "id": "claude-sonnet-4",
1179 "is_chat_default": false,
1180 "is_chat_fallback": false,
1181 "model_picker_enabled": true,
1182 "name": "Claude Sonnet 4",
1183 "object": "model",
1184 "preview": false,
1185 "vendor": "Anthropic",
1186 "version": "claude-sonnet-4"
1187 },
1188 {
1189 "billing": { "is_premium": true, "multiplier": 1 },
1190 "capabilities": {
1191 "family": "claude-opus-4",
1192 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1193 "object": "model_capabilities",
1194 "supports": { "streaming": true, "tool_calls": true },
1195 "type": "chat"
1196 },
1197 "id": "claude-opus-4",
1198 "is_chat_default": false,
1199 "is_chat_fallback": false,
1200 "model_picker_enabled": true,
1201 "name": "Claude Opus 4",
1202 "object": "model",
1203 "preview": false,
1204 "vendor": "Anthropic",
1205 "version": "claude-opus-4"
1206 },
1207 {
1208 "billing": { "is_premium": true, "multiplier": 1 },
1209 "capabilities": {
1210 "family": "claude-sonnet-4.5",
1211 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1212 "object": "model_capabilities",
1213 "supports": { "streaming": true, "tool_calls": true },
1214 "type": "chat"
1215 },
1216 "id": "claude-sonnet-4.5",
1217 "is_chat_default": false,
1218 "is_chat_fallback": false,
1219 "model_picker_enabled": true,
1220 "name": "Claude Sonnet 4.5",
1221 "object": "model",
1222 "preview": false,
1223 "vendor": "Anthropic",
1224 "version": "claude-sonnet-4.5"
1225 }
1226 ],
1227 "object": "list"
1228 }"#;
1229
1230 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1231
1232 // All three Anthropic models should be preserved
1233 assert_eq!(schema.data.len(), 3);
1234 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1235 assert_eq!(schema.data[1].id, "claude-opus-4");
1236 assert_eq!(schema.data[2].id, "claude-sonnet-4.5");
1237 }
1238
1239 #[test]
1240 fn test_models_with_same_family_both_preserved() {
1241 // Test that models sharing the same family (e.g., thinking variants)
1242 // are both preserved in the model list.
1243 let json = r#"{
1244 "data": [
1245 {
1246 "billing": { "is_premium": true, "multiplier": 1 },
1247 "capabilities": {
1248 "family": "claude-sonnet-4",
1249 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1250 "object": "model_capabilities",
1251 "supports": { "streaming": true, "tool_calls": true },
1252 "type": "chat"
1253 },
1254 "id": "claude-sonnet-4",
1255 "is_chat_default": false,
1256 "is_chat_fallback": false,
1257 "model_picker_enabled": true,
1258 "name": "Claude Sonnet 4",
1259 "object": "model",
1260 "preview": false,
1261 "vendor": "Anthropic",
1262 "version": "claude-sonnet-4"
1263 },
1264 {
1265 "billing": { "is_premium": true, "multiplier": 1 },
1266 "capabilities": {
1267 "family": "claude-sonnet-4",
1268 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1269 "object": "model_capabilities",
1270 "supports": { "streaming": true, "tool_calls": true },
1271 "type": "chat"
1272 },
1273 "id": "claude-sonnet-4-thinking",
1274 "is_chat_default": false,
1275 "is_chat_fallback": false,
1276 "model_picker_enabled": true,
1277 "name": "Claude Sonnet 4 (Thinking)",
1278 "object": "model",
1279 "preview": false,
1280 "vendor": "Anthropic",
1281 "version": "claude-sonnet-4-thinking"
1282 }
1283 ],
1284 "object": "list"
1285 }"#;
1286
1287 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1288
1289 // Both models should be preserved even though they share the same family
1290 assert_eq!(schema.data.len(), 2);
1291 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1292 assert_eq!(schema.data[1].id, "claude-sonnet-4-thinking");
1293 }
1294
1295 #[test]
1296 fn test_mixed_vendor_models_all_preserved() {
1297 // Test that models from different vendors are all preserved.
1298 let json = r#"{
1299 "data": [
1300 {
1301 "billing": { "is_premium": false, "multiplier": 1 },
1302 "capabilities": {
1303 "family": "gpt-4o",
1304 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1305 "object": "model_capabilities",
1306 "supports": { "streaming": true, "tool_calls": true },
1307 "type": "chat"
1308 },
1309 "id": "gpt-4o",
1310 "is_chat_default": true,
1311 "is_chat_fallback": false,
1312 "model_picker_enabled": true,
1313 "name": "GPT-4o",
1314 "object": "model",
1315 "preview": false,
1316 "vendor": "Azure OpenAI",
1317 "version": "gpt-4o"
1318 },
1319 {
1320 "billing": { "is_premium": true, "multiplier": 1 },
1321 "capabilities": {
1322 "family": "claude-sonnet-4",
1323 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1324 "object": "model_capabilities",
1325 "supports": { "streaming": true, "tool_calls": true },
1326 "type": "chat"
1327 },
1328 "id": "claude-sonnet-4",
1329 "is_chat_default": false,
1330 "is_chat_fallback": false,
1331 "model_picker_enabled": true,
1332 "name": "Claude Sonnet 4",
1333 "object": "model",
1334 "preview": false,
1335 "vendor": "Anthropic",
1336 "version": "claude-sonnet-4"
1337 },
1338 {
1339 "billing": { "is_premium": true, "multiplier": 1 },
1340 "capabilities": {
1341 "family": "gemini-2.0-flash",
1342 "limits": { "max_context_window_tokens": 1000000, "max_output_tokens": 8192, "max_prompt_tokens": 900000 },
1343 "object": "model_capabilities",
1344 "supports": { "streaming": true, "tool_calls": true },
1345 "type": "chat"
1346 },
1347 "id": "gemini-2.0-flash",
1348 "is_chat_default": false,
1349 "is_chat_fallback": false,
1350 "model_picker_enabled": true,
1351 "name": "Gemini 2.0 Flash",
1352 "object": "model",
1353 "preview": false,
1354 "vendor": "Google",
1355 "version": "gemini-2.0-flash"
1356 }
1357 ],
1358 "object": "list"
1359 }"#;
1360
1361 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1362
1363 // All three models from different vendors should be preserved
1364 assert_eq!(schema.data.len(), 3);
1365 assert_eq!(schema.data[0].id, "gpt-4o");
1366 assert_eq!(schema.data[1].id, "claude-sonnet-4");
1367 assert_eq!(schema.data[2].id, "gemini-2.0-flash");
1368 }
1369
1370 #[test]
1371 fn test_model_with_messages_endpoint_deserializes() {
1372 // Anthropic Claude models use /v1/messages endpoint.
1373 // This test verifies such models deserialize correctly (issue #47540 root cause).
1374 let json = r#"{
1375 "data": [
1376 {
1377 "billing": { "is_premium": true, "multiplier": 1 },
1378 "capabilities": {
1379 "family": "claude-sonnet-4",
1380 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1381 "object": "model_capabilities",
1382 "supports": { "streaming": true, "tool_calls": true },
1383 "type": "chat"
1384 },
1385 "id": "claude-sonnet-4",
1386 "is_chat_default": false,
1387 "is_chat_fallback": false,
1388 "model_picker_enabled": true,
1389 "name": "Claude Sonnet 4",
1390 "object": "model",
1391 "preview": false,
1392 "vendor": "Anthropic",
1393 "version": "claude-sonnet-4",
1394 "supported_endpoints": ["/v1/messages"]
1395 }
1396 ],
1397 "object": "list"
1398 }"#;
1399
1400 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1401
1402 assert_eq!(schema.data.len(), 1);
1403 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1404 assert_eq!(
1405 schema.data[0].supported_endpoints,
1406 vec![ModelSupportedEndpoint::Messages]
1407 );
1408 }
1409
1410 #[test]
1411 fn test_model_with_unknown_endpoint_deserializes() {
1412 // Future-proofing: unknown endpoints should deserialize to Unknown variant
1413 // instead of causing the entire model to fail deserialization.
1414 let json = r#"{
1415 "data": [
1416 {
1417 "billing": { "is_premium": false, "multiplier": 1 },
1418 "capabilities": {
1419 "family": "future-model",
1420 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 8192, "max_prompt_tokens": 120000 },
1421 "object": "model_capabilities",
1422 "supports": { "streaming": true, "tool_calls": true },
1423 "type": "chat"
1424 },
1425 "id": "future-model-v2",
1426 "is_chat_default": false,
1427 "is_chat_fallback": false,
1428 "model_picker_enabled": true,
1429 "name": "Future Model v2",
1430 "object": "model",
1431 "preview": false,
1432 "vendor": "OpenAI",
1433 "version": "v2.0",
1434 "supported_endpoints": ["/v2/completions", "/chat/completions"]
1435 }
1436 ],
1437 "object": "list"
1438 }"#;
1439
1440 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1441
1442 assert_eq!(schema.data.len(), 1);
1443 assert_eq!(schema.data[0].id, "future-model-v2");
1444 assert_eq!(
1445 schema.data[0].supported_endpoints,
1446 vec![
1447 ModelSupportedEndpoint::Unknown,
1448 ModelSupportedEndpoint::ChatCompletions
1449 ]
1450 );
1451 }
1452
1453 #[test]
1454 fn test_model_with_multiple_endpoints() {
1455 // Test model with multiple supported endpoints (common for newer models).
1456 let json = r#"{
1457 "data": [
1458 {
1459 "billing": { "is_premium": true, "multiplier": 1 },
1460 "capabilities": {
1461 "family": "gpt-4o",
1462 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1463 "object": "model_capabilities",
1464 "supports": { "streaming": true, "tool_calls": true },
1465 "type": "chat"
1466 },
1467 "id": "gpt-4o",
1468 "is_chat_default": true,
1469 "is_chat_fallback": false,
1470 "model_picker_enabled": true,
1471 "name": "GPT-4o",
1472 "object": "model",
1473 "preview": false,
1474 "vendor": "OpenAI",
1475 "version": "gpt-4o",
1476 "supported_endpoints": ["/chat/completions", "/responses"]
1477 }
1478 ],
1479 "object": "list"
1480 }"#;
1481
1482 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1483
1484 assert_eq!(schema.data.len(), 1);
1485 assert_eq!(schema.data[0].id, "gpt-4o");
1486 assert_eq!(
1487 schema.data[0].supported_endpoints,
1488 vec![
1489 ModelSupportedEndpoint::ChatCompletions,
1490 ModelSupportedEndpoint::Responses
1491 ]
1492 );
1493 }
1494
1495 #[test]
1496 fn test_supports_response_method() {
1497 // Test the supports_response() method which determines endpoint routing.
1498 let model_with_responses_only = Model {
1499 billing: ModelBilling {
1500 is_premium: false,
1501 multiplier: 1.0,
1502 restricted_to: None,
1503 },
1504 capabilities: ModelCapabilities {
1505 family: "test".to_string(),
1506 limits: ModelLimits::default(),
1507 supports: ModelSupportedFeatures {
1508 streaming: true,
1509 tool_calls: true,
1510 parallel_tool_calls: false,
1511 vision: false,
1512 },
1513 model_type: "chat".to_string(),
1514 tokenizer: None,
1515 },
1516 id: "test-model".to_string(),
1517 name: "Test Model".to_string(),
1518 policy: None,
1519 vendor: ModelVendor::OpenAI,
1520 is_chat_default: false,
1521 is_chat_fallback: false,
1522 model_picker_enabled: true,
1523 supported_endpoints: vec![ModelSupportedEndpoint::Responses],
1524 };
1525
1526 let model_with_chat_completions = Model {
1527 supported_endpoints: vec![ModelSupportedEndpoint::ChatCompletions],
1528 ..model_with_responses_only.clone()
1529 };
1530
1531 let model_with_both = Model {
1532 supported_endpoints: vec![
1533 ModelSupportedEndpoint::ChatCompletions,
1534 ModelSupportedEndpoint::Responses,
1535 ],
1536 ..model_with_responses_only.clone()
1537 };
1538
1539 let model_with_messages = Model {
1540 supported_endpoints: vec![ModelSupportedEndpoint::Messages],
1541 ..model_with_responses_only.clone()
1542 };
1543
1544 // Only /responses endpoint -> supports_response = true
1545 assert!(model_with_responses_only.supports_response());
1546
1547 // Only /chat/completions endpoint -> supports_response = false
1548 assert!(!model_with_chat_completions.supports_response());
1549
1550 // Both endpoints (has /chat/completions) -> supports_response = false
1551 assert!(!model_with_both.supports_response());
1552
1553 // Only /v1/messages endpoint -> supports_response = false (doesn't have /responses)
1554 assert!(!model_with_messages.supports_response());
1555 }
1556}