1use anyhow::Result;
2use collections::BTreeMap;
3use credentials_provider::CredentialsProvider;
4use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
7use language_model::{
8 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
9 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
10 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
11 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
12 env_var,
13};
14use open_ai::ResponseStreamEvent;
15use serde::Deserialize;
16pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
17pub use settings::VercelAiGatewayAvailableModel as AvailableModel;
18use settings::{Settings, SettingsStore};
19use std::sync::{Arc, LazyLock};
20use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
21use ui_input::InputField;
22use util::ResultExt;
23
/// Stable identifier used to key this provider in settings and telemetry.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel_ai_gateway");
/// Human-readable provider name shown in the UI and error messages.
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Vercel AI Gateway");

/// Default API endpoint, used when settings leave `api_url` empty.
const API_URL: &str = "https://ai-gateway.vercel.sh/v1";
/// Environment variable consulted as a fallback source for the API key.
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_AI_GATEWAY_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
31
/// User-configurable settings for the Vercel AI Gateway provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelAiGatewaySettings {
    // Custom API endpoint; an empty string means "use the default gateway URL".
    pub api_url: String,
    // Models added manually in settings (merged with fetched models, and
    // taking precedence over them — see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
37
/// Language-model provider backed by the Vercel AI Gateway's
/// OpenAI-compatible API.
pub struct VercelAiGatewayLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared, observable provider state (API key, fetched model list).
    state: Entity<State>,
}
42
/// Observable state shared between the provider, its models, and the
/// configuration view.
pub struct State {
    // The API key plus where it came from (credential store or env var).
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
    http_client: Arc<dyn HttpClient>,
    // Models fetched from the gateway's `/models` endpoint.
    available_models: Vec<AvailableModel>,
    // Keeps the in-flight model fetch alive; replaced on re-authentication.
    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
}
50
51impl State {
52 fn is_authenticated(&self) -> bool {
53 self.api_key_state.has_key()
54 }
55
56 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
57 let credentials_provider = self.credentials_provider.clone();
58 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
59 self.api_key_state.store(
60 api_url,
61 api_key,
62 |this| &mut this.api_key_state,
63 credentials_provider,
64 cx,
65 )
66 }
67
68 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
69 let credentials_provider = self.credentials_provider.clone();
70 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
71 let task = self.api_key_state.load_if_needed(
72 api_url,
73 |this| &mut this.api_key_state,
74 credentials_provider,
75 cx,
76 );
77
78 cx.spawn(async move |this, cx| {
79 let result = task.await;
80 this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
81 .ok();
82 result
83 })
84 }
85
86 fn fetch_models(
87 &mut self,
88 cx: &mut Context<Self>,
89 ) -> Task<Result<(), LanguageModelCompletionError>> {
90 let http_client = self.http_client.clone();
91 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
92 let api_key = self.api_key_state.key(&api_url);
93 cx.spawn(async move |this, cx| {
94 let models = list_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
95 this.update(cx, |this, cx| {
96 this.available_models = models;
97 cx.notify();
98 })
99 .map_err(|e| LanguageModelCompletionError::Other(e))?;
100 Ok(())
101 })
102 }
103
104 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
105 if self.is_authenticated() {
106 let task = self.fetch_models(cx);
107 self.fetch_models_task.replace(task);
108 } else {
109 self.available_models = Vec::new();
110 }
111 }
112}
113
114impl VercelAiGatewayLanguageModelProvider {
115 pub fn new(
116 http_client: Arc<dyn HttpClient>,
117 credentials_provider: Arc<dyn CredentialsProvider>,
118 cx: &mut App,
119 ) -> Self {
120 let state = cx.new(|cx| {
121 cx.observe_global::<SettingsStore>({
122 let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
123 move |this: &mut State, cx| {
124 let current_settings = VercelAiGatewayLanguageModelProvider::settings(cx);
125 if current_settings != &last_settings {
126 last_settings = current_settings.clone();
127 this.authenticate(cx).detach();
128 cx.notify();
129 }
130 }
131 })
132 .detach();
133 State {
134 api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
135 credentials_provider,
136 http_client: http_client.clone(),
137 available_models: Vec::new(),
138 fetch_models_task: None,
139 }
140 });
141
142 Self { http_client, state }
143 }
144
145 fn settings(cx: &App) -> &VercelAiGatewaySettings {
146 &crate::AllLanguageModelSettings::get_global(cx).vercel_ai_gateway
147 }
148
149 fn api_url(cx: &App) -> SharedString {
150 let api_url = &Self::settings(cx).api_url;
151 if api_url.is_empty() {
152 API_URL.into()
153 } else {
154 SharedString::new(api_url.as_str())
155 }
156 }
157
158 fn default_available_model() -> AvailableModel {
159 AvailableModel {
160 name: "openai/gpt-5.3-codex".to_string(),
161 display_name: Some("GPT 5.3 Codex".to_string()),
162 max_tokens: 400_000,
163 max_output_tokens: Some(128_000),
164 max_completion_tokens: None,
165 capabilities: ModelCapabilities::default(),
166 }
167 }
168
169 fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
170 Arc::new(VercelAiGatewayLanguageModel {
171 id: LanguageModelId::from(model.name.clone()),
172 model,
173 state: self.state.clone(),
174 http_client: self.http_client.clone(),
175 request_limiter: RateLimiter::new(4),
176 })
177 }
178}
179
180impl LanguageModelProviderState for VercelAiGatewayLanguageModelProvider {
181 type ObservableEntity = State;
182
183 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
184 Some(self.state.clone())
185 }
186}
187
188impl LanguageModelProvider for VercelAiGatewayLanguageModelProvider {
189 fn id(&self) -> LanguageModelProviderId {
190 PROVIDER_ID
191 }
192
193 fn name(&self) -> LanguageModelProviderName {
194 PROVIDER_NAME
195 }
196
197 fn icon(&self) -> IconOrSvg {
198 IconOrSvg::Icon(IconName::AiVercel)
199 }
200
201 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
202 Some(self.create_language_model(Self::default_available_model()))
203 }
204
205 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
206 None
207 }
208
209 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
210 let mut models = BTreeMap::default();
211
212 let default_model = Self::default_available_model();
213 models.insert(default_model.name.clone(), default_model);
214
215 for model in self.state.read(cx).available_models.clone() {
216 models.insert(model.name.clone(), model);
217 }
218
219 for model in &Self::settings(cx).available_models {
220 models.insert(model.name.clone(), model.clone());
221 }
222
223 models
224 .into_values()
225 .map(|model| self.create_language_model(model))
226 .collect()
227 }
228
229 fn is_authenticated(&self, cx: &App) -> bool {
230 self.state.read(cx).is_authenticated()
231 }
232
233 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
234 self.state.update(cx, |state, cx| state.authenticate(cx))
235 }
236
237 fn configuration_view(
238 &self,
239 _target_agent: language_model::ConfigurationViewTargetAgent,
240 window: &mut Window,
241 cx: &mut App,
242 ) -> AnyView {
243 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
244 .into()
245 }
246
247 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
248 self.state
249 .update(cx, |state, cx| state.set_api_key(None, cx))
250 }
251}
252
/// A single model served through the Vercel AI Gateway.
pub struct VercelAiGatewayLanguageModel {
    id: LanguageModelId,
    model: AvailableModel,
    // Shared provider state; consulted for the API key and URL per request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps the number of concurrent completion requests to this provider.
    request_limiter: RateLimiter,
}
260
impl VercelAiGatewayLanguageModel {
    /// Issues a streaming OpenAI-compatible completion request through the
    /// gateway, gated by the provider's rate limiter.
    ///
    /// Fails with `NoApiKey` when no key is configured, and maps HTTP-level
    /// failures through `map_open_ai_error`.
    fn stream_open_ai(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();
        // Snapshot the key and URL up front so the async body below is 'static.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = open_ai::stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await.map_err(map_open_ai_error)?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
298
299fn map_open_ai_error(error: open_ai::RequestError) -> LanguageModelCompletionError {
300 match error {
301 open_ai::RequestError::HttpResponseError {
302 status_code,
303 body,
304 headers,
305 ..
306 } => {
307 let retry_after = headers
308 .get(http::header::RETRY_AFTER)
309 .and_then(|value| value.to_str().ok()?.parse::<u64>().ok())
310 .map(std::time::Duration::from_secs);
311
312 LanguageModelCompletionError::from_http_status(
313 PROVIDER_NAME,
314 status_code,
315 extract_error_message(&body),
316 retry_after,
317 )
318 }
319 open_ai::RequestError::Other(error) => LanguageModelCompletionError::Other(error),
320 }
321}
322
323fn extract_error_message(body: &str) -> String {
324 let json = match serde_json::from_str::<serde_json::Value>(body) {
325 Ok(json) => json,
326 Err(_) => return body.to_string(),
327 };
328
329 let message = json
330 .get("error")
331 .and_then(|value| {
332 value
333 .get("message")
334 .and_then(serde_json::Value::as_str)
335 .or_else(|| value.as_str())
336 })
337 .or_else(|| json.get("message").and_then(serde_json::Value::as_str))
338 .map(ToString::to_string)
339 .unwrap_or_else(|| body.to_string());
340
341 clean_error_message(&message)
342}
343
/// Rewrites known confusing gateway error messages into actionable guidance;
/// all other messages pass through unchanged.
fn clean_error_message(message: &str) -> String {
    let normalized = message.to_lowercase();

    // The gateway's OIDC-token error confuses users who pasted the wrong
    // credential kind; point them at gateway API keys instead.
    let oidc_mixup = normalized.contains("vercel_oidc_token") && normalized.contains("oidc token");
    if oidc_mixup {
        return "Authentication failed for Vercel AI Gateway. Use a Vercel AI Gateway key (vck_...).\nCreate or manage keys in Vercel AI Gateway console.\nIf this persists, regenerate the key and update it in Vercel AI Gateway provider settings in Zed.".to_string();
    }

    let bad_key = ["invalid api key", "invalid_api_key"]
        .iter()
        .any(|needle| normalized.contains(needle));
    if bad_key {
        return "Authentication failed for Vercel AI Gateway. Check that your Vercel AI Gateway key starts with vck_ and is active.".to_string();
    }

    message.to_owned()
}
357
/// True when `expected` appears in `tags`, ignoring ASCII case and
/// surrounding whitespace on each tag.
fn has_tag(tags: &[String], expected: &str) -> bool {
    for tag in tags {
        if tag.trim().eq_ignore_ascii_case(expected) {
            return true;
        }
    }
    false
}
362
363impl LanguageModel for VercelAiGatewayLanguageModel {
364 fn id(&self) -> LanguageModelId {
365 self.id.clone()
366 }
367
368 fn name(&self) -> LanguageModelName {
369 LanguageModelName::from(
370 self.model
371 .display_name
372 .clone()
373 .unwrap_or_else(|| self.model.name.clone()),
374 )
375 }
376
377 fn provider_id(&self) -> LanguageModelProviderId {
378 PROVIDER_ID
379 }
380
381 fn provider_name(&self) -> LanguageModelProviderName {
382 PROVIDER_NAME
383 }
384
385 fn supports_tools(&self) -> bool {
386 self.model.capabilities.tools
387 }
388
389 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
390 LanguageModelToolSchemaFormat::JsonSchemaSubset
391 }
392
393 fn supports_images(&self) -> bool {
394 self.model.capabilities.images
395 }
396
397 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
398 match choice {
399 LanguageModelToolChoice::Auto => self.model.capabilities.tools,
400 LanguageModelToolChoice::Any => self.model.capabilities.tools,
401 LanguageModelToolChoice::None => true,
402 }
403 }
404
405 fn supports_streaming_tools(&self) -> bool {
406 true
407 }
408
409 fn supports_split_token_display(&self) -> bool {
410 true
411 }
412
413 fn telemetry_id(&self) -> String {
414 format!("vercel_ai_gateway/{}", self.model.name)
415 }
416
417 fn max_token_count(&self) -> u64 {
418 self.model.max_tokens
419 }
420
421 fn max_output_tokens(&self) -> Option<u64> {
422 self.model.max_output_tokens
423 }
424
425 fn count_tokens(
426 &self,
427 request: LanguageModelRequest,
428 cx: &App,
429 ) -> BoxFuture<'static, Result<u64>> {
430 let max_token_count = self.max_token_count();
431 cx.background_spawn(async move {
432 let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
433 let model = if max_token_count >= 100_000 {
434 "gpt-4o"
435 } else {
436 "gpt-4"
437 };
438 tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
439 })
440 .boxed()
441 }
442
443 fn stream_completion(
444 &self,
445 request: LanguageModelRequest,
446 cx: &AsyncApp,
447 ) -> BoxFuture<
448 'static,
449 Result<
450 futures::stream::BoxStream<
451 'static,
452 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
453 >,
454 LanguageModelCompletionError,
455 >,
456 > {
457 let request = crate::provider::open_ai::into_open_ai(
458 request,
459 &self.model.name,
460 self.model.capabilities.parallel_tool_calls,
461 self.model.capabilities.prompt_cache_key,
462 self.max_output_tokens(),
463 None,
464 );
465 let completions = self.stream_open_ai(request, cx);
466 async move {
467 let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
468 Ok(mapper.map_stream(completions.await?).boxed())
469 }
470 .boxed()
471 }
472}
473
/// Wire format of the gateway's `GET /models` response.
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ApiModel>,
}
478
/// A single model entry as returned by the gateway's `/models` endpoint.
#[derive(Deserialize)]
struct ApiModel {
    id: String,
    // Human-readable name; `id` is used when absent.
    name: Option<String>,
    context_window: Option<u64>,
    max_tokens: Option<u64>,
    // Modality kind (e.g. "language"); non-"language" entries are skipped.
    #[serde(default)]
    r#type: Option<String>,
    // Request parameters the model accepts (e.g. "tools",
    // "parallel_tool_calls") — used for capability detection.
    #[serde(default)]
    supported_parameters: Vec<String>,
    // Free-form capability tags (e.g. "vision", "tool-use") used as a
    // fallback capability signal.
    #[serde(default)]
    tags: Vec<String>,
    architecture: Option<ApiModelArchitecture>,
}
493
/// Architecture metadata; consulted to detect image-input support.
#[derive(Deserialize)]
struct ApiModelArchitecture {
    #[serde(default)]
    input_modalities: Vec<String>,
}
499
/// Fetches the gateway's model catalog and converts it into `AvailableModel`s,
/// skipping entries that are explicitly typed as something other than
/// "language".
///
/// The API key is optional; when absent, no Authorization header is sent.
async fn list_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<AvailableModel>, LanguageModelCompletionError> {
    // NOTE(review): presumably `include_mappings=true` asks the gateway to
    // include provider model aliases — confirm against the gateway API docs.
    let uri = format!("{api_url}/models?include_mappings=true");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");
    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
    }
    let request = request_builder
        .body(AsyncBody::default())
        .map_err(|error| LanguageModelCompletionError::BuildRequestBody {
            provider: PROVIDER_NAME,
            error,
        })?;
    let mut response =
        client
            .send(request)
            .await
            .map_err(|error| LanguageModelCompletionError::HttpSend {
                provider: PROVIDER_NAME,
                error,
            })?;

    // Read the whole body up front so error statuses can surface the message.
    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .map_err(|error| LanguageModelCompletionError::ApiReadResponseError {
            provider: PROVIDER_NAME,
            error,
        })?;

    if !response.status().is_success() {
        return Err(LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            response.status(),
            extract_error_message(&body),
            None,
        ));
    }

    let response: ModelsResponse = serde_json::from_str(&body).map_err(|error| {
        LanguageModelCompletionError::DeserializeResponse {
            provider: PROVIDER_NAME,
            error,
        }
    })?;

    let mut models = Vec::new();
    for model in response.data {
        // Entries without a `type` are kept; only explicit non-"language"
        // types are filtered out.
        if let Some(model_type) = model.r#type.as_deref()
            && model_type != "language"
        {
            continue;
        }
        // Capability detection: prefer explicit `supported_parameters`,
        // fall back to free-form tags.
        let supports_tools = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "tools")
            || has_tag(&model.tags, "tool-use")
            || has_tag(&model.tags, "tools");
        let supports_images = model.architecture.is_some_and(|architecture| {
            architecture
                .input_modalities
                .iter()
                .any(|modality| modality == "image")
        }) || has_tag(&model.tags, "vision")
            || has_tag(&model.tags, "image-input");
        let parallel_tool_calls = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "parallel_tool_calls");
        let prompt_cache_key = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "prompt_cache_key" || parameter == "cache_control");
        models.push(AvailableModel {
            name: model.id.clone(),
            display_name: model.name.or(Some(model.id)),
            // Prefer the full context window; fall back to max output tokens,
            // then a conservative 128k default.
            max_tokens: model.context_window.or(model.max_tokens).unwrap_or(128_000),
            max_output_tokens: model.max_tokens,
            max_completion_tokens: None,
            capabilities: ModelCapabilities {
                tools: supports_tools,
                images: supports_images,
                parallel_tool_calls,
                prompt_cache_key,
                chat_completions: true,
            },
        });
    }

    Ok(models)
}
600
/// Settings-panel view for entering/resetting the Vercel AI Gateway API key.
struct ConfigurationView {
    // Input field the user pastes the `vck_…` key into.
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Some while stored credentials are still loading; shows a spinner label.
    load_credentials_task: Option<Task<()>>,
}
606
607impl ConfigurationView {
608 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
609 let api_key_editor =
610 cx.new(|cx| InputField::new(window, cx, "vck_000000000000000000000000000"));
611
612 cx.observe(&state, |_, _, cx| cx.notify()).detach();
613
614 let load_credentials_task = Some(cx.spawn_in(window, {
615 let state = state.clone();
616 async move |this, cx| {
617 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
618 let _ = task.await;
619 }
620 this.update(cx, |this, cx| {
621 this.load_credentials_task = None;
622 cx.notify();
623 })
624 .log_err();
625 }
626 }));
627
628 Self {
629 api_key_editor,
630 state,
631 load_credentials_task,
632 }
633 }
634
635 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
636 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
637 if api_key.is_empty() {
638 return;
639 }
640
641 self.api_key_editor
642 .update(cx, |editor, cx| editor.set_text("", window, cx));
643
644 let state = self.state.clone();
645 cx.spawn_in(window, async move |_, cx| {
646 state
647 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
648 .await
649 })
650 .detach_and_log_err(cx);
651 }
652
653 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
654 self.api_key_editor
655 .update(cx, |editor, cx| editor.set_text("", window, cx));
656
657 let state = self.state.clone();
658 cx.spawn_in(window, async move |_, cx| {
659 state
660 .update(cx, |state, cx| state.set_api_key(None, cx))
661 .await
662 })
663 .detach_and_log_err(cx);
664 }
665
666 fn should_render_editor(&self, cx: &Context<Self>) -> bool {
667 !self.state.read(cx).is_authenticated()
668 }
669}
670
impl Render for ConfigurationView {
    /// Renders one of three states: a loading label while credentials load,
    /// the key-entry editor when unauthenticated, or a "configured" card.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        // Card label describes where the key came from (env var vs. stored)
        // and, for non-default endpoints, which URL it applies to.
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            if api_url == API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                // Enter in the editor triggers `save_api_key` via menu::Confirm.
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Vercel AI Gateway, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create an API key in"))
                                .child(ButtonLink::new(
                                    "Vercel AI Gateway's console",
                                    "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys&title=Go+to+AI+Gateway",
                                )),
                        )
                        .child(ListBulletItem::new(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.",
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Key is configured: clicking the card resets it, unless it came
            // from the environment (which Zed cannot unset).
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .into_any_element()
        }
    }
}