1use anyhow::Result;
2use collections::BTreeMap;
3use credentials_provider::CredentialsProvider;
4use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
7use language_model::{
8 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
9 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
10 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
11 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
12 env_var,
13};
14use open_ai::ResponseStreamEvent;
15use serde::Deserialize;
16pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
17pub use settings::VercelAiGatewayAvailableModel as AvailableModel;
18use settings::{Settings, SettingsStore};
19use std::sync::{Arc, LazyLock};
20use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
21use ui_input::InputField;
22use util::ResultExt;
23
/// Stable identifier used to register this provider with the model registry.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel_ai_gateway");
/// Human-readable provider name shown in the UI and in error messages.
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Vercel AI Gateway");

/// Default API endpoint, used when settings leave `api_url` empty.
const API_URL: &str = "https://ai-gateway.vercel.sh/v1";
/// Environment variable consulted as an alternative to a stored API key.
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_AI_GATEWAY_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
31
/// Provider settings as read from the user's settings file.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelAiGatewaySettings {
    /// Custom API base URL; the default `API_URL` is used when this is empty.
    pub api_url: String,
    /// User-declared models, merged on top of the fetched model list.
    pub available_models: Vec<AvailableModel>,
}
37
/// Language model provider backed by the Vercel AI Gateway's
/// OpenAI-compatible API.
pub struct VercelAiGatewayLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared provider state (API key, fetched models); observed by the UI.
    state: Entity<State>,
}
42
/// Mutable provider state held in a gpui entity.
pub struct State {
    // Tracks the API key and whether it came from the environment variable.
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
    http_client: Arc<dyn HttpClient>,
    // Models fetched from the gateway's models endpoint.
    available_models: Vec<AvailableModel>,
    // Keeps the in-flight fetch alive; replaced on re-authentication.
    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
}
50
51impl State {
52 fn is_authenticated(&self) -> bool {
53 self.api_key_state.has_key()
54 }
55
56 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
57 let credentials_provider = self.credentials_provider.clone();
58 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
59 self.api_key_state.store(
60 api_url,
61 api_key,
62 |this| &mut this.api_key_state,
63 credentials_provider,
64 cx,
65 )
66 }
67
68 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
69 let credentials_provider = self.credentials_provider.clone();
70 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
71 let task = self.api_key_state.load_if_needed(
72 api_url,
73 |this| &mut this.api_key_state,
74 credentials_provider,
75 cx,
76 );
77
78 cx.spawn(async move |this, cx| {
79 let result = task.await;
80 this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
81 .ok();
82 result
83 })
84 }
85
86 fn fetch_models(
87 &mut self,
88 cx: &mut Context<Self>,
89 ) -> Task<Result<(), LanguageModelCompletionError>> {
90 let http_client = self.http_client.clone();
91 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
92 let api_key = self.api_key_state.key(&api_url);
93 cx.spawn(async move |this, cx| {
94 let models = list_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
95 this.update(cx, |this, cx| {
96 this.available_models = models;
97 cx.notify();
98 })
99 .map_err(|e| LanguageModelCompletionError::Other(e))?;
100 Ok(())
101 })
102 }
103
104 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
105 if self.is_authenticated() {
106 let task = self.fetch_models(cx);
107 self.fetch_models_task.replace(task);
108 } else {
109 self.available_models = Vec::new();
110 }
111 }
112}
113
impl VercelAiGatewayLanguageModelProvider {
    /// Creates the provider and wires up a global-settings observer that
    /// re-authenticates whenever this provider's settings change.
    pub fn new(
        http_client: Arc<dyn HttpClient>,
        credentials_provider: Arc<dyn CredentialsProvider>,
        cx: &mut App,
    ) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>({
                // Snapshot current settings so we only react to actual changes,
                // not every SettingsStore notification.
                let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
                move |this: &mut State, cx| {
                    let current_settings = VercelAiGatewayLanguageModelProvider::settings(cx);
                    if current_settings != &last_settings {
                        last_settings = current_settings.clone();
                        this.authenticate(cx).detach();
                        cx.notify();
                    }
                }
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
                credentials_provider,
                http_client: http_client.clone(),
                available_models: Vec::new(),
                fetch_models_task: None,
            }
        });

        Self { http_client, state }
    }

    /// This provider's section of the global language model settings.
    fn settings(cx: &App) -> &VercelAiGatewaySettings {
        &crate::AllLanguageModelSettings::get_global(cx).vercel_ai_gateway
    }

    /// Effective API base URL: the configured one, or the default when unset.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }

    /// Built-in fallback model, offered even before the catalog is fetched.
    fn default_available_model() -> AvailableModel {
        AvailableModel {
            name: "openai/gpt-5.3-codex".to_string(),
            display_name: Some("GPT 5.3 Codex".to_string()),
            max_tokens: 400_000,
            max_output_tokens: Some(128_000),
            max_completion_tokens: None,
            capabilities: ModelCapabilities::default(),
        }
    }

    /// Wraps an `AvailableModel` in a `LanguageModel` handle that shares this
    /// provider's state and HTTP client.
    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
        Arc::new(VercelAiGatewayLanguageModel {
            id: LanguageModelId::from(model.name.clone()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}
179
180impl LanguageModelProviderState for VercelAiGatewayLanguageModelProvider {
181 type ObservableEntity = State;
182
183 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
184 Some(self.state.clone())
185 }
186}
187
188impl LanguageModelProvider for VercelAiGatewayLanguageModelProvider {
189 fn id(&self) -> LanguageModelProviderId {
190 PROVIDER_ID
191 }
192
193 fn name(&self) -> LanguageModelProviderName {
194 PROVIDER_NAME
195 }
196
197 fn icon(&self) -> IconOrSvg {
198 IconOrSvg::Icon(IconName::AiVercel)
199 }
200
201 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
202 Some(self.create_language_model(Self::default_available_model()))
203 }
204
205 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
206 None
207 }
208
209 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
210 let mut models = BTreeMap::default();
211
212 let default_model = Self::default_available_model();
213 models.insert(default_model.name.clone(), default_model);
214
215 for model in self.state.read(cx).available_models.clone() {
216 models.insert(model.name.clone(), model);
217 }
218
219 for model in &Self::settings(cx).available_models {
220 models.insert(model.name.clone(), model.clone());
221 }
222
223 models
224 .into_values()
225 .map(|model| self.create_language_model(model))
226 .collect()
227 }
228
229 fn is_authenticated(&self, cx: &App) -> bool {
230 self.state.read(cx).is_authenticated()
231 }
232
233 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
234 self.state.update(cx, |state, cx| state.authenticate(cx))
235 }
236
237 fn configuration_view(
238 &self,
239 _target_agent: language_model::ConfigurationViewTargetAgent,
240 window: &mut Window,
241 cx: &mut App,
242 ) -> AnyView {
243 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
244 .into()
245 }
246
247 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
248 self.state
249 .update(cx, |state, cx| state.set_api_key(None, cx))
250 }
251}
252
/// A single gateway-hosted model, streaming completions over the
/// OpenAI-compatible protocol.
pub struct VercelAiGatewayLanguageModel {
    id: LanguageModelId,
    model: AvailableModel,
    // Shared provider state; consulted for the API key at request time.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps concurrent in-flight completion requests.
    request_limiter: RateLimiter,
}
260
impl VercelAiGatewayLanguageModel {
    /// Opens a streaming chat completion against the gateway's
    /// OpenAI-compatible endpoint, throttled by `request_limiter`.
    fn stream_open_ai(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();
        // Read the key and URL up front so the returned future is 'static.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            // Fail fast with a provider-specific error if no key is configured.
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = open_ai::stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            // Map transport/HTTP failures into LanguageModelCompletionError.
            let response = request.await.map_err(map_open_ai_error)?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
298
299fn map_open_ai_error(error: open_ai::RequestError) -> LanguageModelCompletionError {
300 match error {
301 open_ai::RequestError::HttpResponseError {
302 status_code,
303 body,
304 headers,
305 ..
306 } => {
307 let retry_after = headers
308 .get(http::header::RETRY_AFTER)
309 .and_then(|value| value.to_str().ok()?.parse::<u64>().ok())
310 .map(std::time::Duration::from_secs);
311
312 LanguageModelCompletionError::from_http_status(
313 PROVIDER_NAME,
314 status_code,
315 extract_error_message(&body),
316 retry_after,
317 )
318 }
319 open_ai::RequestError::Other(error) => LanguageModelCompletionError::Other(error),
320 }
321}
322
323fn extract_error_message(body: &str) -> String {
324 let json = match serde_json::from_str::<serde_json::Value>(body) {
325 Ok(json) => json,
326 Err(_) => return body.to_string(),
327 };
328
329 let message = json
330 .get("error")
331 .and_then(|value| {
332 value
333 .get("message")
334 .and_then(serde_json::Value::as_str)
335 .or_else(|| value.as_str())
336 })
337 .or_else(|| json.get("message").and_then(serde_json::Value::as_str))
338 .map(ToString::to_string)
339 .unwrap_or_else(|| body.to_string());
340
341 clean_error_message(&message)
342}
343
/// Rewrites known Vercel AI Gateway authentication failures into actionable
/// guidance; any other message passes through unchanged.
fn clean_error_message(message: &str) -> String {
    let normalized = message.to_lowercase();

    // User supplied a Vercel OIDC token instead of a gateway API key.
    let oidc_mixup =
        normalized.contains("vercel_oidc_token") && normalized.contains("oidc token");
    if oidc_mixup {
        return "Authentication failed for Vercel AI Gateway. Use a Vercel AI Gateway key (vck_...).\nCreate or manage keys in Vercel AI Gateway console.\nIf this persists, regenerate the key and update it in Vercel AI Gateway provider settings in Zed.".to_string();
    }

    // Generic invalid-key responses from the gateway.
    let bad_key =
        normalized.contains("invalid api key") || normalized.contains("invalid_api_key");
    if bad_key {
        return "Authentication failed for Vercel AI Gateway. Check that your Vercel AI Gateway key starts with vck_ and is active.".to_string();
    }

    message.to_string()
}
357
/// Case-insensitive membership test over a tag list; each tag is trimmed
/// before comparison.
fn has_tag(tags: &[String], expected: &str) -> bool {
    for tag in tags {
        if tag.trim().eq_ignore_ascii_case(expected) {
            return true;
        }
    }
    false
}
362
363impl LanguageModel for VercelAiGatewayLanguageModel {
364 fn id(&self) -> LanguageModelId {
365 self.id.clone()
366 }
367
368 fn name(&self) -> LanguageModelName {
369 LanguageModelName::from(
370 self.model
371 .display_name
372 .clone()
373 .unwrap_or_else(|| self.model.name.clone()),
374 )
375 }
376
377 fn provider_id(&self) -> LanguageModelProviderId {
378 PROVIDER_ID
379 }
380
381 fn provider_name(&self) -> LanguageModelProviderName {
382 PROVIDER_NAME
383 }
384
385 fn supports_tools(&self) -> bool {
386 self.model.capabilities.tools
387 }
388
389 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
390 LanguageModelToolSchemaFormat::JsonSchemaSubset
391 }
392
393 fn supports_images(&self) -> bool {
394 self.model.capabilities.images
395 }
396
397 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
398 match choice {
399 LanguageModelToolChoice::Auto => self.model.capabilities.tools,
400 LanguageModelToolChoice::Any => self.model.capabilities.tools,
401 LanguageModelToolChoice::None => true,
402 }
403 }
404
405 fn supports_streaming_tools(&self) -> bool {
406 true
407 }
408
409 fn supports_split_token_display(&self) -> bool {
410 true
411 }
412
413 fn telemetry_id(&self) -> String {
414 format!("vercel_ai_gateway/{}", self.model.name)
415 }
416
417 fn max_token_count(&self) -> u64 {
418 self.model.max_tokens
419 }
420
421 fn max_output_tokens(&self) -> Option<u64> {
422 self.model.max_output_tokens
423 }
424
425 fn stream_completion(
426 &self,
427 request: LanguageModelRequest,
428 cx: &AsyncApp,
429 ) -> BoxFuture<
430 'static,
431 Result<
432 futures::stream::BoxStream<
433 'static,
434 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
435 >,
436 LanguageModelCompletionError,
437 >,
438 > {
439 let request = crate::provider::open_ai::into_open_ai(
440 request,
441 &self.model.name,
442 self.model.capabilities.parallel_tool_calls,
443 self.model.capabilities.prompt_cache_key,
444 self.max_output_tokens(),
445 None,
446 false,
447 );
448 let completions = self.stream_open_ai(request, cx);
449 async move {
450 let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
451 Ok(mapper.map_stream(completions.await?).boxed())
452 }
453 .boxed()
454 }
455}
456
/// Shape of the gateway's models-list response.
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ApiModel>,
}
461
/// A single model entry as returned by the gateway's models endpoint.
#[derive(Deserialize)]
struct ApiModel {
    id: String,
    // Optional human-readable name; `id` is used as a fallback display name.
    name: Option<String>,
    context_window: Option<u64>,
    max_tokens: Option<u64>,
    // Model category (e.g. "language"); non-"language" entries are skipped.
    #[serde(default)]
    r#type: Option<String>,
    // Parameter names the model accepts, e.g. "tools", "parallel_tool_calls".
    #[serde(default)]
    supported_parameters: Vec<String>,
    // Free-form capability tags, e.g. "vision", "tool-use".
    #[serde(default)]
    tags: Vec<String>,
    architecture: Option<ApiModelArchitecture>,
}
476
/// Architecture metadata; used to detect image-input support.
#[derive(Deserialize)]
struct ApiModelArchitecture {
    #[serde(default)]
    input_modalities: Vec<String>,
}
482
/// Fetches the gateway's model catalog and converts each language model into
/// an `AvailableModel`, inferring capabilities from `supported_parameters`,
/// `tags`, and architecture metadata.
///
/// `api_key` is optional; when present it is sent as a bearer token.
async fn list_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<AvailableModel>, LanguageModelCompletionError> {
    let uri = format!("{api_url}/models?include_mappings=true");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");
    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
    }
    let request = request_builder
        .body(AsyncBody::default())
        .map_err(|error| LanguageModelCompletionError::BuildRequestBody {
            provider: PROVIDER_NAME,
            error,
        })?;
    let mut response =
        client
            .send(request)
            .await
            .map_err(|error| LanguageModelCompletionError::HttpSend {
                provider: PROVIDER_NAME,
                error,
            })?;

    // Read the body before the status check so error responses can still be
    // mined for a useful message.
    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .map_err(|error| LanguageModelCompletionError::ApiReadResponseError {
            provider: PROVIDER_NAME,
            error,
        })?;

    if !response.status().is_success() {
        return Err(LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            response.status(),
            extract_error_message(&body),
            None,
        ));
    }

    let response: ModelsResponse = serde_json::from_str(&body).map_err(|error| {
        LanguageModelCompletionError::DeserializeResponse {
            provider: PROVIDER_NAME,
            error,
        }
    })?;

    let mut models = Vec::new();
    for model in response.data {
        // Skip entries with an explicit non-"language" type (e.g. embeddings).
        if let Some(model_type) = model.r#type.as_deref()
            && model_type != "language"
        {
            continue;
        }
        // Tool support: advertised via parameters or capability tags.
        let supports_tools = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "tools")
            || has_tag(&model.tags, "tool-use")
            || has_tag(&model.tags, "tools");
        // Image input: advertised via input modalities or vision-style tags.
        let supports_images = model.architecture.is_some_and(|architecture| {
            architecture
                .input_modalities
                .iter()
                .any(|modality| modality == "image")
        }) || has_tag(&model.tags, "vision")
            || has_tag(&model.tags, "image-input");
        let parallel_tool_calls = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "parallel_tool_calls");
        let prompt_cache_key = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "prompt_cache_key" || parameter == "cache_control");
        models.push(AvailableModel {
            name: model.id.clone(),
            display_name: model.name.or(Some(model.id)),
            // Fall back to a conservative 128k context when unspecified.
            max_tokens: model.context_window.or(model.max_tokens).unwrap_or(128_000),
            max_output_tokens: model.max_tokens,
            max_completion_tokens: None,
            capabilities: ModelCapabilities {
                tools: supports_tools,
                images: supports_images,
                parallel_tool_calls,
                prompt_cache_key,
                chat_completions: true,
                interleaved_reasoning: false,
            },
        });
    }

    Ok(models)
}
584
/// Settings UI for entering or resetting the Vercel AI Gateway API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Present while initial credential loading runs; drives the loading UI.
    load_credentials_task: Option<Task<()>>,
}
590
591impl ConfigurationView {
592 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
593 let api_key_editor =
594 cx.new(|cx| InputField::new(window, cx, "vck_000000000000000000000000000"));
595
596 cx.observe(&state, |_, _, cx| cx.notify()).detach();
597
598 let load_credentials_task = Some(cx.spawn_in(window, {
599 let state = state.clone();
600 async move |this, cx| {
601 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
602 let _ = task.await;
603 }
604 this.update(cx, |this, cx| {
605 this.load_credentials_task = None;
606 cx.notify();
607 })
608 .log_err();
609 }
610 }));
611
612 Self {
613 api_key_editor,
614 state,
615 load_credentials_task,
616 }
617 }
618
619 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
620 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
621 if api_key.is_empty() {
622 return;
623 }
624
625 self.api_key_editor
626 .update(cx, |editor, cx| editor.set_text("", window, cx));
627
628 let state = self.state.clone();
629 cx.spawn_in(window, async move |_, cx| {
630 state
631 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
632 .await
633 })
634 .detach_and_log_err(cx);
635 }
636
637 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
638 self.api_key_editor
639 .update(cx, |editor, cx| editor.set_text("", window, cx));
640
641 let state = self.state.clone();
642 cx.spawn_in(window, async move |_, cx| {
643 state
644 .update(cx, |state, cx| state.set_api_key(None, cx))
645 .await
646 })
647 .detach_and_log_err(cx);
648 }
649
650 fn should_render_editor(&self, cx: &Context<Self>) -> bool {
651 !self.state.read(cx).is_authenticated()
652 }
653}
654
impl Render for ConfigurationView {
    /// Renders one of three states: loading, key entry, or configured-key card.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Keys sourced from the environment variable can't be reset in the UI.
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            if api_url == API_URL {
                "API key configured".to_string()
            } else {
                // Mention the custom endpoint so users know which URL the key targets.
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            // Credential load still in flight: show a placeholder.
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            // Unauthenticated: setup instructions plus the key input field.
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Vercel AI Gateway, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create an API key in"))
                                .child(ButtonLink::new(
                                    "Vercel AI Gateway's console",
                                    "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys&title=Go+to+AI+Gateway",
                                )),
                        )
                        .child(ListBulletItem::new(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.",
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Authenticated: configured-key card; clicking resets the key
            // unless it came from the environment.
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .into_any_element()
        }
    }
}