1pub mod responses;
2
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::sync::OnceLock;
6
7use anyhow::Context as _;
8use anyhow::{Result, anyhow};
9use collections::HashSet;
10use fs::Fs;
11use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
12use gpui::WeakEntity;
13use gpui::{App, AsyncApp, Global, prelude::*};
14use http_client::HttpRequestExt;
15use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
16use paths::home_dir;
17use serde::{Deserialize, Serialize};
18
19use settings::watch_config_dir;
20
21pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
22const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
23
/// Configuration for GitHub Copilot Chat, optionally targeting a GitHub
/// Enterprise deployment instead of github.com.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    /// Enterprise instance URI (e.g. "https://github.example.com"); `None`
    /// means the public github.com service.
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    /// Domain used for the OAuth flow: the enterprise host when configured,
    /// otherwise "github.com".
    pub fn oauth_domain(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            Self::parse_domain(enterprise_uri)
        } else {
            "github.com".to_string()
        }
    }

    /// GraphQL API URL used for endpoint discovery. Enterprise hosts serve it
    /// under `/api/graphql`; github.com uses the dedicated api subdomain.
    pub fn graphql_url(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            let domain = Self::parse_domain(enterprise_uri);
            format!("https://{}/api/graphql", domain)
        } else {
            "https://api.github.com/graphql".to_string()
        }
    }

    /// URL of the chat completions endpoint under the given API base.
    pub fn chat_completions_url(&self, api_endpoint: &str) -> String {
        format!("{}/chat/completions", api_endpoint)
    }

    /// URL of the responses endpoint under the given API base.
    pub fn responses_url(&self, api_endpoint: &str) -> String {
        format!("{}/responses", api_endpoint)
    }

    /// URL of the model-listing endpoint under the given API base.
    pub fn models_url(&self, api_endpoint: &str) -> String {
        format!("{}/models", api_endpoint)
    }

    /// Extracts the bare host (including any port) from an enterprise URI,
    /// accepting "https://host/...", "http://host/...", or a bare "host/...".
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');

        // The two scheme prefixes were previously handled by duplicated
        // branches; a single strip-or-fallback chain covers all three cases.
        let authority = uri
            .strip_prefix("https://")
            .or_else(|| uri.strip_prefix("http://"))
            .unwrap_or(uri);

        authority.split('/').next().unwrap_or(authority).to_string()
    }
}
71
/// Author role of a chat message, serialized lowercase on the wire.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

/// Endpoint paths a model advertises in the `/models` listing.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
    #[serde(rename = "/v1/messages")]
    Messages,
    /// Unknown endpoint that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

/// Wire shape of the `/models` response. Individual entries that fail to
/// deserialize are skipped (see `deserialize_models_skip_errors` below)
/// rather than failing the whole list.
#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
98
99fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
100where
101 D: serde::Deserializer<'de>,
102{
103 let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
104 let models = raw_values
105 .into_iter()
106 .filter_map(|value| match serde_json::from_value::<Model>(value) {
107 Ok(model) => Some(model),
108 Err(err) => {
109 log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
110 None
111 }
112 })
113 .collect();
114
115 Ok(models)
116}
117
/// A model entry as returned by the Copilot `/models` endpoint.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    // API identifier, e.g. "gpt-4o"; exposed via `Model::id`.
    id: String,
    // Display name; exposed via `Model::display_name`.
    name: String,
    // `None` when no policy gates the model; `get_models` filters on `state`.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    // Missing field deserializes to an empty list (older API responses).
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

/// Billing information for a model.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // Premium-request cost multiplier; exposed via `Model::multiplier`.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

/// Capability metadata for a model (family, limits, feature flags).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // "chat" models are the only ones kept by `get_models`.
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}

/// Token limits advertised for a model; all fields default to 0 when absent.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

/// Per-user policy for a gated model; only "enabled" models are surfaced.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

/// Feature flags for a model; missing flags deserialize to `false`.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

/// Upstream provider of a model.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

/// One part of a multipart chat message, tagged by "type" on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Wrapper carrying the URL (or data URI) of an image message part.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
211
212impl Model {
213 pub fn uses_streaming(&self) -> bool {
214 self.capabilities.supports.streaming
215 }
216
217 pub fn id(&self) -> &str {
218 self.id.as_str()
219 }
220
221 pub fn display_name(&self) -> &str {
222 self.name.as_str()
223 }
224
225 pub fn max_token_count(&self) -> u64 {
226 self.capabilities.limits.max_context_window_tokens as u64
227 }
228
229 pub fn supports_tools(&self) -> bool {
230 self.capabilities.supports.tool_calls
231 }
232
233 pub fn vendor(&self) -> ModelVendor {
234 self.vendor
235 }
236
237 pub fn supports_vision(&self) -> bool {
238 self.capabilities.supports.vision
239 }
240
241 pub fn supports_parallel_tool_calls(&self) -> bool {
242 self.capabilities.supports.parallel_tool_calls
243 }
244
245 pub fn tokenizer(&self) -> Option<&str> {
246 self.capabilities.tokenizer.as_deref()
247 }
248
249 pub fn supports_response(&self) -> bool {
250 self.supported_endpoints.len() > 0
251 && !self
252 .supported_endpoints
253 .contains(&ModelSupportedEndpoint::ChatCompletions)
254 && self
255 .supported_endpoints
256 .contains(&ModelSupportedEndpoint::Responses)
257 }
258
259 pub fn multiplier(&self) -> f64 {
260 self.billing.multiplier
261 }
262}
263
/// Request body for the Copilot chat completions endpoint.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    // When true, the server answers with server-sent events; see
    // `stream_completion`, which branches on this flag.
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// Declaration of a callable function exposed to the model.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // JSON schema describing the function's parameters.
    pub parameters: serde_json::Value,
}

/// A tool made available to the model, tagged by "type" on the wire.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// How the model should decide whether to call tools.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

/// A chat message, discriminated by the "role" field on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        // Opaque reasoning blob echoed back to preserve model reasoning state.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    // System messages are plain strings and can never carry image parts.
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

/// Message body: either a plain string or a list of typed parts.
/// Untagged, so serde picks whichever variant matches the JSON shape.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}
329
330impl ChatMessageContent {
331 pub fn empty() -> Self {
332 ChatMessageContent::Multipart(vec![])
333 }
334}
335
336impl From<Vec<ChatMessagePart>> for ChatMessageContent {
337 fn from(mut parts: Vec<ChatMessagePart>) -> Self {
338 if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
339 ChatMessageContent::Plain(std::mem::take(text))
340 } else {
341 ChatMessageContent::Multipart(parts)
342 }
343 }
344}
345
346impl From<String> for ChatMessageContent {
347 fn from(text: String) -> Self {
348 ChatMessageContent::Plain(text)
349 }
350}
351
/// A tool invocation requested by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    // Flattened so `type`/`function` appear alongside `id` on the wire.
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call, tagged by "type" on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Function name plus JSON-encoded argument string for a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

/// One event from the chat completions endpoint (a streamed SSE chunk or the
/// single body of a non-streaming response).
// NOTE(review): `tag`/`rename_all` on a *struct* only influence how serde
// names things, and this type derives Deserialize only — confirm these
// container attributes are intentional rather than copied from an enum.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}

/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a response event; streaming responses populate `delta`,
/// non-streaming responses populate `message`.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Incremental (or complete) message content within a choice.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

/// A partial tool call delivered across streaming chunks; `index` correlates
/// fragments belonging to the same call.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function-call data within a tool call chunk.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}
418
/// Global handle to the singleton [`CopilotChat`] entity, registered by `init`.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

/// Entity holding Copilot Chat state: the OAuth token, the discovered API
/// endpoint, and the cached model list.
pub struct CopilotChat {
    // Token read from the Copilot config files or the GH_COPILOT_TOKEN env var.
    oauth_token: Option<String>,
    // Lazily discovered API base URL; `None` until resolved.
    api_endpoint: Option<String>,
    configuration: CopilotChatConfiguration,
    // Models fetched from `/models`; `None` until the first successful fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
430
/// Creates the [`CopilotChat`] entity and registers it as a gpui global so
/// the rest of the app can reach it via [`CopilotChat::global`].
pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}
440
441pub fn copilot_chat_config_dir() -> &'static PathBuf {
442 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
443
444 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
445 let config_dir = if cfg!(target_os = "windows") {
446 dirs::data_local_dir().expect("failed to determine LocalAppData directory")
447 } else {
448 std::env::var("XDG_CONFIG_HOME")
449 .map(PathBuf::from)
450 .unwrap_or_else(|_| home_dir().join(".config"))
451 };
452
453 config_dir.join("github-copilot")
454 })
455}
456
457fn copilot_chat_config_paths() -> [PathBuf; 2] {
458 let base_dir = copilot_chat_config_dir();
459 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
460}
461
impl CopilotChat {
    /// Returns the global [`CopilotChat`] entity, if `init` has been called.
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    /// Builds the entity: seeds the OAuth token from `GH_COPILOT_TOKEN`, and
    /// spawns a watcher on the Copilot config files that re-extracts the
    /// token (and refreshes models) whenever they change.
    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        // Long-lived watcher task: each emitted `contents` is a config file's
        // text, from which the token for the configured domain is extracted.
        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                // Only hit the network once we actually have credentials.
                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            // Environment-variable token takes effect immediately, before any
            // config file event arrives.
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_endpoint: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

    /// Resolves the API endpoint, fetches the model list from `/models`, and
    /// stores it on the entity.
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let api_endpoint =
            Self::resolve_api_endpoint(&this, &oauth_token, &configuration, &client, cx).await?;

        let models_url = configuration.models_url(&api_endpoint);
        let models = get_models(models_url.into(), oauth_token, client.clone()).await?;

        this.update(cx, |this, cx| {
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    /// True when an OAuth token is available (env var or config file).
    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    /// Cached model list, if it has been fetched.
    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    /// Streams a chat completion via the `/chat/completions` endpoint using
    /// the globally stored credentials and configuration.
    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, oauth_token, api_endpoint, configuration) =
            Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url(&api_endpoint);
        stream_completion(
            client.clone(),
            oauth_token,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    /// Streams a response via the `/responses` endpoint (for models that only
    /// support that API; see `Model::supports_response`).
    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, oauth_token, api_endpoint, configuration) =
            Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url(&api_endpoint);
        responses::stream_response(
            client.clone(),
            oauth_token,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

    /// Gathers everything needed for an API call: HTTP client, OAuth token,
    /// resolved API endpoint, and configuration. Discovers the endpoint on
    /// first use.
    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(
        Arc<dyn HttpClient>,
        String,
        String,
        CopilotChatConfiguration,
    )> {
        let this = cx
            .update(|cx| Self::global(cx))
            .context("Copilot chat is not enabled")?;

        // NOTE(review): the sibling `read_with` calls in this file propagate a
        // `Result` with `?` — confirm this call intentionally differs.
        let (oauth_token, api_endpoint, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_endpoint.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        });

        let oauth_token = oauth_token.context("No OAuth token available")?;

        let api_endpoint = match api_endpoint {
            Some(endpoint) => endpoint,
            None => {
                let weak = this.downgrade();
                Self::resolve_api_endpoint(&weak, &oauth_token, &configuration, &client, cx).await?
            }
        };

        Ok((client, oauth_token, api_endpoint, configuration))
    }

    /// Discovers the API endpoint via GraphQL, falling back to the public
    /// default on failure, and caches the result on the entity.
    async fn resolve_api_endpoint(
        this: &WeakEntity<Self>,
        oauth_token: &str,
        configuration: &CopilotChatConfiguration,
        client: &Arc<dyn HttpClient>,
        cx: &mut AsyncApp,
    ) -> Result<String> {
        let api_endpoint = match discover_api_endpoint(oauth_token, configuration, client).await {
            Ok(endpoint) => endpoint,
            Err(error) => {
                // Discovery failure is non-fatal; the public endpoint works
                // for non-enterprise accounts.
                log::warn!(
                    "Failed to discover Copilot API endpoint via GraphQL, \
                     falling back to {DEFAULT_COPILOT_API_ENDPOINT}: {error:#}"
                );
                DEFAULT_COPILOT_API_ENDPOINT.to_string()
            }
        };

        this.update(cx, |this, cx| {
            this.api_endpoint = Some(api_endpoint.clone());
            cx.notify();
        })?;

        Ok(api_endpoint)
    }

    /// Replaces the configuration. When it actually changed, the cached
    /// endpoint is invalidated and the model list is refreshed.
    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_endpoint = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}
666
667async fn get_models(
668 models_url: Arc<str>,
669 oauth_token: String,
670 client: Arc<dyn HttpClient>,
671) -> Result<Vec<Model>> {
672 let all_models = request_models(models_url, oauth_token, client).await?;
673
674 let mut models: Vec<Model> = all_models
675 .into_iter()
676 .filter(|model| {
677 model.model_picker_enabled
678 && model.capabilities.model_type.as_str() == "chat"
679 && model
680 .policy
681 .as_ref()
682 .is_none_or(|policy| policy.state == "enabled")
683 })
684 .collect();
685
686 if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
687 let default_model = models.remove(default_model_position);
688 models.insert(0, default_model);
689 }
690
691 Ok(models)
692}
693
/// Envelope of the endpoint-discovery GraphQL response.
#[derive(Deserialize)]
struct GraphQLResponse {
    // Presumably `None` when the server reports only errors — discovery
    // treats a missing `data` field as a failure either way.
    data: Option<GraphQLData>,
}

#[derive(Deserialize)]
struct GraphQLData {
    viewer: GraphQLViewer,
}

#[derive(Deserialize)]
struct GraphQLViewer {
    #[serde(rename = "copilotEndpoints")]
    copilot_endpoints: GraphQLCopilotEndpoints,
}

/// Copilot service endpoints advertised for the authenticated user.
#[derive(Deserialize)]
struct GraphQLCopilotEndpoints {
    api: String,
}
714
715pub(crate) async fn discover_api_endpoint(
716 oauth_token: &str,
717 configuration: &CopilotChatConfiguration,
718 client: &Arc<dyn HttpClient>,
719) -> Result<String> {
720 let graphql_url = configuration.graphql_url();
721 let query = serde_json::json!({
722 "query": "query { viewer { copilotEndpoints { api } } }"
723 });
724
725 let request = HttpRequest::builder()
726 .method(Method::POST)
727 .uri(graphql_url.as_str())
728 .header("Authorization", format!("Bearer {}", oauth_token))
729 .header("Content-Type", "application/json")
730 .body(AsyncBody::from(serde_json::to_string(&query)?))?;
731
732 let mut response = client.send(request).await?;
733
734 anyhow::ensure!(
735 response.status().is_success(),
736 "GraphQL endpoint discovery failed: {}",
737 response.status()
738 );
739
740 let mut body = Vec::new();
741 response.body_mut().read_to_end(&mut body).await?;
742 let body_str = std::str::from_utf8(&body)?;
743
744 let parsed: GraphQLResponse = serde_json::from_str(body_str)
745 .context("Failed to parse GraphQL response for Copilot endpoint discovery")?;
746
747 let data = parsed
748 .data
749 .context("GraphQL response contained no data field")?;
750
751 Ok(data.viewer.copilot_endpoints.api)
752}
753
754pub(crate) fn copilot_request_headers(
755 builder: http_client::Builder,
756 oauth_token: &str,
757 is_user_initiated: Option<bool>,
758) -> http_client::Builder {
759 builder
760 .header("Authorization", format!("Bearer {}", oauth_token))
761 .header("Content-Type", "application/json")
762 .header(
763 "Editor-Version",
764 format!(
765 "Zed/{}",
766 option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
767 ),
768 )
769 .when_some(is_user_initiated, |builder, is_user_initiated| {
770 builder.header(
771 "X-Initiator",
772 if is_user_initiated { "user" } else { "agent" },
773 )
774 })
775}
776
777async fn request_models(
778 models_url: Arc<str>,
779 oauth_token: String,
780 client: Arc<dyn HttpClient>,
781) -> Result<Vec<Model>> {
782 let request_builder = copilot_request_headers(
783 HttpRequest::builder()
784 .method(Method::GET)
785 .uri(models_url.as_ref()),
786 &oauth_token,
787 None,
788 )
789 .header("x-github-api-version", "2025-05-01");
790
791 let request = request_builder.body(AsyncBody::empty())?;
792
793 let mut response = client.send(request).await?;
794
795 anyhow::ensure!(
796 response.status().is_success(),
797 "Failed to request models: {}",
798 response.status()
799 );
800 let mut body = Vec::new();
801 response.body_mut().read_to_end(&mut body).await?;
802
803 let body_str = std::str::from_utf8(&body)?;
804
805 let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
806
807 Ok(models)
808}
809
810fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
811 serde_json::from_str::<serde_json::Value>(&contents)
812 .map(|v| {
813 v.as_object().and_then(|obj| {
814 obj.iter().find_map(|(key, value)| {
815 if key.starts_with(domain) {
816 value["oauth_token"].as_str().map(|v| v.to_string())
817 } else {
818 None
819 }
820 })
821 })
822 })
823 .ok()
824 .flatten()
825}
826
/// Sends a chat completion request and returns the result as a stream.
///
/// When `request.stream` is true the body is parsed as server-sent events;
/// otherwise the single JSON body is wrapped in a one-item stream so callers
/// handle both modes uniformly.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    oauth_token: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // Copilot requires an extra header when any message carries an image part.
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        // Only `System` remains; its content is a plain string, never parts.
        _ => false,
    });

    let request_builder = copilot_request_headers(
        HttpRequest::builder()
            .method(Method::POST)
            .uri(completion_url.as_ref()),
        &oauth_token,
        Some(is_user_initiated),
    )
    .when(is_vision_request, |builder| {
        builder.header("Copilot-Vision-Request", is_vision_request.to_string())
    });

    // Read the flag before `request` is shadowed by the HTTP request below.
    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        // Include the response body in the error to aid debugging.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // SSE data frames are prefixed with "data: "; other
                        // lines (blank keep-alives, comments) are dropped.
                        let line = line.strip_prefix("data: ")?;
                        // "[DONE]" is the server's end-of-stream sentinel.
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Events without choices carry nothing to
                                // surface downstream; skip them.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-streaming: parse the whole body as one event and wrap it in a
        // single-item stream.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
907
908#[cfg(test)]
909mod tests {
910 use super::*;
911
912 #[test]
913 fn test_resilient_model_schema_deserialize() {
914 let json = r#"{
915 "data": [
916 {
917 "billing": {
918 "is_premium": false,
919 "multiplier": 0
920 },
921 "capabilities": {
922 "family": "gpt-4",
923 "limits": {
924 "max_context_window_tokens": 32768,
925 "max_output_tokens": 4096,
926 "max_prompt_tokens": 32768
927 },
928 "object": "model_capabilities",
929 "supports": { "streaming": true, "tool_calls": true },
930 "tokenizer": "cl100k_base",
931 "type": "chat"
932 },
933 "id": "gpt-4",
934 "is_chat_default": false,
935 "is_chat_fallback": false,
936 "model_picker_enabled": false,
937 "name": "GPT 4",
938 "object": "model",
939 "preview": false,
940 "vendor": "Azure OpenAI",
941 "version": "gpt-4-0613"
942 },
943 {
944 "some-unknown-field": 123
945 },
946 {
947 "billing": {
948 "is_premium": true,
949 "multiplier": 1,
950 "restricted_to": [
951 "pro",
952 "pro_plus",
953 "business",
954 "enterprise"
955 ]
956 },
957 "capabilities": {
958 "family": "claude-3.7-sonnet",
959 "limits": {
960 "max_context_window_tokens": 200000,
961 "max_output_tokens": 16384,
962 "max_prompt_tokens": 90000,
963 "vision": {
964 "max_prompt_image_size": 3145728,
965 "max_prompt_images": 1,
966 "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
967 }
968 },
969 "object": "model_capabilities",
970 "supports": {
971 "parallel_tool_calls": true,
972 "streaming": true,
973 "tool_calls": true,
974 "vision": true
975 },
976 "tokenizer": "o200k_base",
977 "type": "chat"
978 },
979 "id": "claude-3.7-sonnet",
980 "is_chat_default": false,
981 "is_chat_fallback": false,
982 "model_picker_enabled": true,
983 "name": "Claude 3.7 Sonnet",
984 "object": "model",
985 "policy": {
986 "state": "enabled",
987 "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
988 },
989 "preview": false,
990 "vendor": "Anthropic",
991 "version": "claude-3.7-sonnet"
992 }
993 ],
994 "object": "list"
995 }"#;
996
997 let schema: ModelSchema = serde_json::from_str(json).unwrap();
998
999 assert_eq!(schema.data.len(), 2);
1000 assert_eq!(schema.data[0].id, "gpt-4");
1001 assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
1002 }
1003
1004 #[test]
1005 fn test_unknown_vendor_resilience() {
1006 let json = r#"{
1007 "data": [
1008 {
1009 "billing": {
1010 "is_premium": false,
1011 "multiplier": 1
1012 },
1013 "capabilities": {
1014 "family": "future-model",
1015 "limits": {
1016 "max_context_window_tokens": 128000,
1017 "max_output_tokens": 8192,
1018 "max_prompt_tokens": 120000
1019 },
1020 "object": "model_capabilities",
1021 "supports": { "streaming": true, "tool_calls": true },
1022 "type": "chat"
1023 },
1024 "id": "future-model-v1",
1025 "is_chat_default": false,
1026 "is_chat_fallback": false,
1027 "model_picker_enabled": true,
1028 "name": "Future Model v1",
1029 "object": "model",
1030 "preview": false,
1031 "vendor": "SomeNewVendor",
1032 "version": "v1.0"
1033 }
1034 ],
1035 "object": "list"
1036 }"#;
1037
1038 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1039
1040 assert_eq!(schema.data.len(), 1);
1041 assert_eq!(schema.data[0].id, "future-model-v1");
1042 assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
1043 }
1044
1045 #[test]
1046 fn test_max_token_count_returns_context_window_not_prompt_tokens() {
1047 let json = r#"{
1048 "data": [
1049 {
1050 "billing": { "is_premium": true, "multiplier": 1 },
1051 "capabilities": {
1052 "family": "claude-sonnet-4",
1053 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1054 "object": "model_capabilities",
1055 "supports": { "streaming": true, "tool_calls": true },
1056 "type": "chat"
1057 },
1058 "id": "claude-sonnet-4",
1059 "is_chat_default": false,
1060 "is_chat_fallback": false,
1061 "model_picker_enabled": true,
1062 "name": "Claude Sonnet 4",
1063 "object": "model",
1064 "preview": false,
1065 "vendor": "Anthropic",
1066 "version": "claude-sonnet-4"
1067 },
1068 {
1069 "billing": { "is_premium": false, "multiplier": 1 },
1070 "capabilities": {
1071 "family": "gpt-4o",
1072 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1073 "object": "model_capabilities",
1074 "supports": { "streaming": true, "tool_calls": true },
1075 "type": "chat"
1076 },
1077 "id": "gpt-4o",
1078 "is_chat_default": true,
1079 "is_chat_fallback": false,
1080 "model_picker_enabled": true,
1081 "name": "GPT-4o",
1082 "object": "model",
1083 "preview": false,
1084 "vendor": "Azure OpenAI",
1085 "version": "gpt-4o"
1086 }
1087 ],
1088 "object": "list"
1089 }"#;
1090
1091 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1092
1093 // max_token_count() should return context window (200000), not prompt tokens (90000)
1094 assert_eq!(schema.data[0].max_token_count(), 200000);
1095
1096 // GPT-4o should return 128000 (context window), not 110000 (prompt tokens)
1097 assert_eq!(schema.data[1].max_token_count(), 128000);
1098 }
1099
1100 #[test]
1101 fn test_models_with_pending_policy_deserialize() {
1102 // This test verifies that models with policy states other than "enabled"
1103 // (such as "pending" or "requires_consent") are properly deserialized.
1104 // Note: These models will be filtered out by get_models() and won't appear
1105 // in the model picker until the user enables them on GitHub.
1106 let json = r#"{
1107 "data": [
1108 {
1109 "billing": { "is_premium": true, "multiplier": 1 },
1110 "capabilities": {
1111 "family": "claude-sonnet-4",
1112 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1113 "object": "model_capabilities",
1114 "supports": { "streaming": true, "tool_calls": true },
1115 "type": "chat"
1116 },
1117 "id": "claude-sonnet-4",
1118 "is_chat_default": false,
1119 "is_chat_fallback": false,
1120 "model_picker_enabled": true,
1121 "name": "Claude Sonnet 4",
1122 "object": "model",
1123 "policy": {
1124 "state": "pending",
1125 "terms": "Enable access to Claude models from Anthropic."
1126 },
1127 "preview": false,
1128 "vendor": "Anthropic",
1129 "version": "claude-sonnet-4"
1130 },
1131 {
1132 "billing": { "is_premium": true, "multiplier": 1 },
1133 "capabilities": {
1134 "family": "claude-opus-4",
1135 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1136 "object": "model_capabilities",
1137 "supports": { "streaming": true, "tool_calls": true },
1138 "type": "chat"
1139 },
1140 "id": "claude-opus-4",
1141 "is_chat_default": false,
1142 "is_chat_fallback": false,
1143 "model_picker_enabled": true,
1144 "name": "Claude Opus 4",
1145 "object": "model",
1146 "policy": {
1147 "state": "requires_consent",
1148 "terms": "Enable access to Claude models from Anthropic."
1149 },
1150 "preview": false,
1151 "vendor": "Anthropic",
1152 "version": "claude-opus-4"
1153 }
1154 ],
1155 "object": "list"
1156 }"#;
1157
1158 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1159
1160 // Both models should deserialize successfully (filtering happens in get_models)
1161 assert_eq!(schema.data.len(), 2);
1162 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1163 assert_eq!(schema.data[1].id, "claude-opus-4");
1164 }
1165
1166 #[test]
1167 fn test_multiple_anthropic_models_preserved() {
1168 // This test verifies that multiple Claude models from Anthropic
1169 // are all preserved and not incorrectly deduplicated.
1170 // This was the root cause of issue #47540.
1171 let json = r#"{
1172 "data": [
1173 {
1174 "billing": { "is_premium": true, "multiplier": 1 },
1175 "capabilities": {
1176 "family": "claude-sonnet-4",
1177 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1178 "object": "model_capabilities",
1179 "supports": { "streaming": true, "tool_calls": true },
1180 "type": "chat"
1181 },
1182 "id": "claude-sonnet-4",
1183 "is_chat_default": false,
1184 "is_chat_fallback": false,
1185 "model_picker_enabled": true,
1186 "name": "Claude Sonnet 4",
1187 "object": "model",
1188 "preview": false,
1189 "vendor": "Anthropic",
1190 "version": "claude-sonnet-4"
1191 },
1192 {
1193 "billing": { "is_premium": true, "multiplier": 1 },
1194 "capabilities": {
1195 "family": "claude-opus-4",
1196 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1197 "object": "model_capabilities",
1198 "supports": { "streaming": true, "tool_calls": true },
1199 "type": "chat"
1200 },
1201 "id": "claude-opus-4",
1202 "is_chat_default": false,
1203 "is_chat_fallback": false,
1204 "model_picker_enabled": true,
1205 "name": "Claude Opus 4",
1206 "object": "model",
1207 "preview": false,
1208 "vendor": "Anthropic",
1209 "version": "claude-opus-4"
1210 },
1211 {
1212 "billing": { "is_premium": true, "multiplier": 1 },
1213 "capabilities": {
1214 "family": "claude-sonnet-4.5",
1215 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1216 "object": "model_capabilities",
1217 "supports": { "streaming": true, "tool_calls": true },
1218 "type": "chat"
1219 },
1220 "id": "claude-sonnet-4.5",
1221 "is_chat_default": false,
1222 "is_chat_fallback": false,
1223 "model_picker_enabled": true,
1224 "name": "Claude Sonnet 4.5",
1225 "object": "model",
1226 "preview": false,
1227 "vendor": "Anthropic",
1228 "version": "claude-sonnet-4.5"
1229 }
1230 ],
1231 "object": "list"
1232 }"#;
1233
1234 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1235
1236 // All three Anthropic models should be preserved
1237 assert_eq!(schema.data.len(), 3);
1238 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1239 assert_eq!(schema.data[1].id, "claude-opus-4");
1240 assert_eq!(schema.data[2].id, "claude-sonnet-4.5");
1241 }
1242
1243 #[test]
1244 fn test_models_with_same_family_both_preserved() {
1245 // Test that models sharing the same family (e.g., thinking variants)
1246 // are both preserved in the model list.
1247 let json = r#"{
1248 "data": [
1249 {
1250 "billing": { "is_premium": true, "multiplier": 1 },
1251 "capabilities": {
1252 "family": "claude-sonnet-4",
1253 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1254 "object": "model_capabilities",
1255 "supports": { "streaming": true, "tool_calls": true },
1256 "type": "chat"
1257 },
1258 "id": "claude-sonnet-4",
1259 "is_chat_default": false,
1260 "is_chat_fallback": false,
1261 "model_picker_enabled": true,
1262 "name": "Claude Sonnet 4",
1263 "object": "model",
1264 "preview": false,
1265 "vendor": "Anthropic",
1266 "version": "claude-sonnet-4"
1267 },
1268 {
1269 "billing": { "is_premium": true, "multiplier": 1 },
1270 "capabilities": {
1271 "family": "claude-sonnet-4",
1272 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1273 "object": "model_capabilities",
1274 "supports": { "streaming": true, "tool_calls": true },
1275 "type": "chat"
1276 },
1277 "id": "claude-sonnet-4-thinking",
1278 "is_chat_default": false,
1279 "is_chat_fallback": false,
1280 "model_picker_enabled": true,
1281 "name": "Claude Sonnet 4 (Thinking)",
1282 "object": "model",
1283 "preview": false,
1284 "vendor": "Anthropic",
1285 "version": "claude-sonnet-4-thinking"
1286 }
1287 ],
1288 "object": "list"
1289 }"#;
1290
1291 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1292
1293 // Both models should be preserved even though they share the same family
1294 assert_eq!(schema.data.len(), 2);
1295 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1296 assert_eq!(schema.data[1].id, "claude-sonnet-4-thinking");
1297 }
1298
1299 #[test]
1300 fn test_mixed_vendor_models_all_preserved() {
1301 // Test that models from different vendors are all preserved.
1302 let json = r#"{
1303 "data": [
1304 {
1305 "billing": { "is_premium": false, "multiplier": 1 },
1306 "capabilities": {
1307 "family": "gpt-4o",
1308 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1309 "object": "model_capabilities",
1310 "supports": { "streaming": true, "tool_calls": true },
1311 "type": "chat"
1312 },
1313 "id": "gpt-4o",
1314 "is_chat_default": true,
1315 "is_chat_fallback": false,
1316 "model_picker_enabled": true,
1317 "name": "GPT-4o",
1318 "object": "model",
1319 "preview": false,
1320 "vendor": "Azure OpenAI",
1321 "version": "gpt-4o"
1322 },
1323 {
1324 "billing": { "is_premium": true, "multiplier": 1 },
1325 "capabilities": {
1326 "family": "claude-sonnet-4",
1327 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1328 "object": "model_capabilities",
1329 "supports": { "streaming": true, "tool_calls": true },
1330 "type": "chat"
1331 },
1332 "id": "claude-sonnet-4",
1333 "is_chat_default": false,
1334 "is_chat_fallback": false,
1335 "model_picker_enabled": true,
1336 "name": "Claude Sonnet 4",
1337 "object": "model",
1338 "preview": false,
1339 "vendor": "Anthropic",
1340 "version": "claude-sonnet-4"
1341 },
1342 {
1343 "billing": { "is_premium": true, "multiplier": 1 },
1344 "capabilities": {
1345 "family": "gemini-2.0-flash",
1346 "limits": { "max_context_window_tokens": 1000000, "max_output_tokens": 8192, "max_prompt_tokens": 900000 },
1347 "object": "model_capabilities",
1348 "supports": { "streaming": true, "tool_calls": true },
1349 "type": "chat"
1350 },
1351 "id": "gemini-2.0-flash",
1352 "is_chat_default": false,
1353 "is_chat_fallback": false,
1354 "model_picker_enabled": true,
1355 "name": "Gemini 2.0 Flash",
1356 "object": "model",
1357 "preview": false,
1358 "vendor": "Google",
1359 "version": "gemini-2.0-flash"
1360 }
1361 ],
1362 "object": "list"
1363 }"#;
1364
1365 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1366
1367 // All three models from different vendors should be preserved
1368 assert_eq!(schema.data.len(), 3);
1369 assert_eq!(schema.data[0].id, "gpt-4o");
1370 assert_eq!(schema.data[1].id, "claude-sonnet-4");
1371 assert_eq!(schema.data[2].id, "gemini-2.0-flash");
1372 }
1373
1374 #[test]
1375 fn test_model_with_messages_endpoint_deserializes() {
1376 // Anthropic Claude models use /v1/messages endpoint.
1377 // This test verifies such models deserialize correctly (issue #47540 root cause).
1378 let json = r#"{
1379 "data": [
1380 {
1381 "billing": { "is_premium": true, "multiplier": 1 },
1382 "capabilities": {
1383 "family": "claude-sonnet-4",
1384 "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1385 "object": "model_capabilities",
1386 "supports": { "streaming": true, "tool_calls": true },
1387 "type": "chat"
1388 },
1389 "id": "claude-sonnet-4",
1390 "is_chat_default": false,
1391 "is_chat_fallback": false,
1392 "model_picker_enabled": true,
1393 "name": "Claude Sonnet 4",
1394 "object": "model",
1395 "preview": false,
1396 "vendor": "Anthropic",
1397 "version": "claude-sonnet-4",
1398 "supported_endpoints": ["/v1/messages"]
1399 }
1400 ],
1401 "object": "list"
1402 }"#;
1403
1404 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1405
1406 assert_eq!(schema.data.len(), 1);
1407 assert_eq!(schema.data[0].id, "claude-sonnet-4");
1408 assert_eq!(
1409 schema.data[0].supported_endpoints,
1410 vec![ModelSupportedEndpoint::Messages]
1411 );
1412 }
1413
1414 #[test]
1415 fn test_model_with_unknown_endpoint_deserializes() {
1416 // Future-proofing: unknown endpoints should deserialize to Unknown variant
1417 // instead of causing the entire model to fail deserialization.
1418 let json = r#"{
1419 "data": [
1420 {
1421 "billing": { "is_premium": false, "multiplier": 1 },
1422 "capabilities": {
1423 "family": "future-model",
1424 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 8192, "max_prompt_tokens": 120000 },
1425 "object": "model_capabilities",
1426 "supports": { "streaming": true, "tool_calls": true },
1427 "type": "chat"
1428 },
1429 "id": "future-model-v2",
1430 "is_chat_default": false,
1431 "is_chat_fallback": false,
1432 "model_picker_enabled": true,
1433 "name": "Future Model v2",
1434 "object": "model",
1435 "preview": false,
1436 "vendor": "OpenAI",
1437 "version": "v2.0",
1438 "supported_endpoints": ["/v2/completions", "/chat/completions"]
1439 }
1440 ],
1441 "object": "list"
1442 }"#;
1443
1444 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1445
1446 assert_eq!(schema.data.len(), 1);
1447 assert_eq!(schema.data[0].id, "future-model-v2");
1448 assert_eq!(
1449 schema.data[0].supported_endpoints,
1450 vec![
1451 ModelSupportedEndpoint::Unknown,
1452 ModelSupportedEndpoint::ChatCompletions
1453 ]
1454 );
1455 }
1456
1457 #[test]
1458 fn test_model_with_multiple_endpoints() {
1459 // Test model with multiple supported endpoints (common for newer models).
1460 let json = r#"{
1461 "data": [
1462 {
1463 "billing": { "is_premium": true, "multiplier": 1 },
1464 "capabilities": {
1465 "family": "gpt-4o",
1466 "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1467 "object": "model_capabilities",
1468 "supports": { "streaming": true, "tool_calls": true },
1469 "type": "chat"
1470 },
1471 "id": "gpt-4o",
1472 "is_chat_default": true,
1473 "is_chat_fallback": false,
1474 "model_picker_enabled": true,
1475 "name": "GPT-4o",
1476 "object": "model",
1477 "preview": false,
1478 "vendor": "OpenAI",
1479 "version": "gpt-4o",
1480 "supported_endpoints": ["/chat/completions", "/responses"]
1481 }
1482 ],
1483 "object": "list"
1484 }"#;
1485
1486 let schema: ModelSchema = serde_json::from_str(json).unwrap();
1487
1488 assert_eq!(schema.data.len(), 1);
1489 assert_eq!(schema.data[0].id, "gpt-4o");
1490 assert_eq!(
1491 schema.data[0].supported_endpoints,
1492 vec![
1493 ModelSupportedEndpoint::ChatCompletions,
1494 ModelSupportedEndpoint::Responses
1495 ]
1496 );
1497 }
1498
1499 #[test]
1500 fn test_supports_response_method() {
1501 // Test the supports_response() method which determines endpoint routing.
1502 let model_with_responses_only = Model {
1503 billing: ModelBilling {
1504 is_premium: false,
1505 multiplier: 1.0,
1506 restricted_to: None,
1507 },
1508 capabilities: ModelCapabilities {
1509 family: "test".to_string(),
1510 limits: ModelLimits::default(),
1511 supports: ModelSupportedFeatures {
1512 streaming: true,
1513 tool_calls: true,
1514 parallel_tool_calls: false,
1515 vision: false,
1516 },
1517 model_type: "chat".to_string(),
1518 tokenizer: None,
1519 },
1520 id: "test-model".to_string(),
1521 name: "Test Model".to_string(),
1522 policy: None,
1523 vendor: ModelVendor::OpenAI,
1524 is_chat_default: false,
1525 is_chat_fallback: false,
1526 model_picker_enabled: true,
1527 supported_endpoints: vec![ModelSupportedEndpoint::Responses],
1528 };
1529
1530 let model_with_chat_completions = Model {
1531 supported_endpoints: vec![ModelSupportedEndpoint::ChatCompletions],
1532 ..model_with_responses_only.clone()
1533 };
1534
1535 let model_with_both = Model {
1536 supported_endpoints: vec![
1537 ModelSupportedEndpoint::ChatCompletions,
1538 ModelSupportedEndpoint::Responses,
1539 ],
1540 ..model_with_responses_only.clone()
1541 };
1542
1543 let model_with_messages = Model {
1544 supported_endpoints: vec![ModelSupportedEndpoint::Messages],
1545 ..model_with_responses_only.clone()
1546 };
1547
1548 // Only /responses endpoint -> supports_response = true
1549 assert!(model_with_responses_only.supports_response());
1550
1551 // Only /chat/completions endpoint -> supports_response = false
1552 assert!(!model_with_chat_completions.supports_response());
1553
1554 // Both endpoints (has /chat/completions) -> supports_response = false
1555 assert!(!model_with_both.supports_response());
1556
1557 // Only /v1/messages endpoint -> supports_response = false (doesn't have /responses)
1558 assert!(!model_with_messages.supports_response());
1559 }
1560}