1use anyhow::Result;
2use collections::BTreeMap;
3use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
4use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
6use language_model::{
7 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
8 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
10 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
11 env_var,
12};
13use open_ai::ResponseStreamEvent;
14use serde::Deserialize;
15pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
16pub use settings::VercelAiGatewayAvailableModel as AvailableModel;
17use settings::{Settings, SettingsStore};
18use std::sync::{Arc, LazyLock};
19use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
20use ui_input::InputField;
21use util::ResultExt;
22
/// Stable machine identifier for this provider (used in settings/telemetry keys).
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel_ai_gateway");
/// Human-readable provider name shown in the UI.
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Vercel AI Gateway");

/// Default API endpoint, used when settings do not override `api_url`.
const API_URL: &str = "https://ai-gateway.vercel.sh/v1";
/// Environment variable consulted for the API key when none is stored.
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_AI_GATEWAY_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
30
/// User-configurable settings for the Vercel AI Gateway provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelAiGatewaySettings {
    /// Custom API endpoint; the default `API_URL` is used when this is empty.
    pub api_url: String,
    /// Models declared manually in settings, merged with fetched models.
    pub available_models: Vec<AvailableModel>,
}
36
/// Language model provider backed by the Vercel AI Gateway's
/// OpenAI-compatible API.
pub struct VercelAiGatewayLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}
41
/// Shared provider state: API credentials plus the dynamically fetched model list.
pub struct State {
    // API key storage/lookup, keyed by the current API URL.
    api_key_state: ApiKeyState,
    http_client: Arc<dyn HttpClient>,
    // Models fetched from the gateway's `/models` endpoint.
    available_models: Vec<AvailableModel>,
    // Holds the in-flight model-list fetch so it isn't dropped; replaced on re-auth.
    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
}
48
49impl State {
50 fn is_authenticated(&self) -> bool {
51 self.api_key_state.has_key()
52 }
53
54 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
55 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
56 self.api_key_state
57 .store(api_url, api_key, |this| &mut this.api_key_state, cx)
58 }
59
60 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
61 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
62 let task = self
63 .api_key_state
64 .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
65
66 cx.spawn(async move |this, cx| {
67 let result = task.await;
68 this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
69 .ok();
70 result
71 })
72 }
73
74 fn fetch_models(
75 &mut self,
76 cx: &mut Context<Self>,
77 ) -> Task<Result<(), LanguageModelCompletionError>> {
78 let http_client = self.http_client.clone();
79 let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
80 let api_key = self.api_key_state.key(&api_url);
81 cx.spawn(async move |this, cx| {
82 let models = list_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
83 this.update(cx, |this, cx| {
84 this.available_models = models;
85 cx.notify();
86 })
87 .map_err(|e| LanguageModelCompletionError::Other(e))?;
88 Ok(())
89 })
90 }
91
92 fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
93 if self.is_authenticated() {
94 let task = self.fetch_models(cx);
95 self.fetch_models_task.replace(task);
96 } else {
97 self.available_models = Vec::new();
98 }
99 }
100}
101
impl VercelAiGatewayLanguageModelProvider {
    /// Creates the provider and registers a settings observer that
    /// re-authenticates whenever this provider's settings change.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>({
                // Snapshot current settings so we only react to real changes.
                let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
                move |this: &mut State, cx| {
                    let current_settings = VercelAiGatewayLanguageModelProvider::settings(cx);
                    if current_settings != &last_settings {
                        last_settings = current_settings.clone();
                        this.authenticate(cx).detach();
                        cx.notify();
                    }
                }
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
                http_client: http_client.clone(),
                available_models: Vec::new(),
                fetch_models_task: None,
            }
        });

        Self { http_client, state }
    }

    /// This provider's section of the global language-model settings.
    fn settings(cx: &App) -> &VercelAiGatewaySettings {
        &crate::AllLanguageModelSettings::get_global(cx).vercel_ai_gateway
    }

    /// Effective API URL: the settings override, or the default gateway URL.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }

    /// Built-in fallback model, offered even before the model list is fetched.
    fn default_available_model() -> AvailableModel {
        AvailableModel {
            name: "openai/gpt-5.3-codex".to_string(),
            display_name: Some("GPT 5.3 Codex".to_string()),
            max_tokens: 400_000,
            max_output_tokens: Some(128_000),
            max_completion_tokens: None,
            capabilities: ModelCapabilities::default(),
        }
    }

    /// Wraps an `AvailableModel` in a `LanguageModel` handle that shares this
    /// provider's state and HTTP client.
    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
        Arc::new(VercelAiGatewayLanguageModel {
            id: LanguageModelId::from(model.name.clone()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}
162
163impl LanguageModelProviderState for VercelAiGatewayLanguageModelProvider {
164 type ObservableEntity = State;
165
166 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
167 Some(self.state.clone())
168 }
169}
170
171impl LanguageModelProvider for VercelAiGatewayLanguageModelProvider {
172 fn id(&self) -> LanguageModelProviderId {
173 PROVIDER_ID
174 }
175
176 fn name(&self) -> LanguageModelProviderName {
177 PROVIDER_NAME
178 }
179
180 fn icon(&self) -> IconOrSvg {
181 IconOrSvg::Icon(IconName::AiVercel)
182 }
183
184 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
185 Some(self.create_language_model(Self::default_available_model()))
186 }
187
188 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
189 None
190 }
191
192 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
193 let mut models = BTreeMap::default();
194
195 let default_model = Self::default_available_model();
196 models.insert(default_model.name.clone(), default_model);
197
198 for model in self.state.read(cx).available_models.clone() {
199 models.insert(model.name.clone(), model);
200 }
201
202 for model in &Self::settings(cx).available_models {
203 models.insert(model.name.clone(), model.clone());
204 }
205
206 models
207 .into_values()
208 .map(|model| self.create_language_model(model))
209 .collect()
210 }
211
212 fn is_authenticated(&self, cx: &App) -> bool {
213 self.state.read(cx).is_authenticated()
214 }
215
216 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
217 self.state.update(cx, |state, cx| state.authenticate(cx))
218 }
219
220 fn configuration_view(
221 &self,
222 _target_agent: language_model::ConfigurationViewTargetAgent,
223 window: &mut Window,
224 cx: &mut App,
225 ) -> AnyView {
226 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
227 .into()
228 }
229
230 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
231 self.state
232 .update(cx, |state, cx| state.set_api_key(None, cx))
233 }
234}
235
/// A single model served through the Vercel AI Gateway.
pub struct VercelAiGatewayLanguageModel {
    id: LanguageModelId,
    model: AvailableModel,
    // Shared provider state; consulted for the API key at request time.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps concurrent in-flight completion requests.
    request_limiter: RateLimiter,
}
243
impl VercelAiGatewayLanguageModel {
    /// Starts an OpenAI-compatible streaming completion against the gateway,
    /// rate-limited by `request_limiter`. Fails with `NoApiKey` when no
    /// credential is available.
    fn stream_open_ai(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();
        // Snapshot the key and URL up front so the 'static future doesn't
        // borrow `self` or the app context.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = open_ai::stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            // Translate transport/HTTP failures into provider-tagged errors.
            let response = request.await.map_err(map_open_ai_error)?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
281
282fn map_open_ai_error(error: open_ai::RequestError) -> LanguageModelCompletionError {
283 match error {
284 open_ai::RequestError::HttpResponseError {
285 status_code,
286 body,
287 headers,
288 ..
289 } => {
290 let retry_after = headers
291 .get(http::header::RETRY_AFTER)
292 .and_then(|value| value.to_str().ok()?.parse::<u64>().ok())
293 .map(std::time::Duration::from_secs);
294
295 LanguageModelCompletionError::from_http_status(
296 PROVIDER_NAME,
297 status_code,
298 extract_error_message(&body),
299 retry_after,
300 )
301 }
302 open_ai::RequestError::Other(error) => LanguageModelCompletionError::Other(error),
303 }
304}
305
306fn extract_error_message(body: &str) -> String {
307 let json = match serde_json::from_str::<serde_json::Value>(body) {
308 Ok(json) => json,
309 Err(_) => return body.to_string(),
310 };
311
312 let message = json
313 .get("error")
314 .and_then(|value| {
315 value
316 .get("message")
317 .and_then(serde_json::Value::as_str)
318 .or_else(|| value.as_str())
319 })
320 .or_else(|| json.get("message").and_then(serde_json::Value::as_str))
321 .map(ToString::to_string)
322 .unwrap_or_else(|| body.to_string());
323
324 clean_error_message(&message)
325}
326
/// Replaces known confusing gateway auth errors with actionable guidance;
/// all other messages pass through unchanged.
fn clean_error_message(message: &str) -> String {
    let lower = message.to_lowercase();
    let mentions = |needle: &str| lower.contains(needle);

    if mentions("vercel_oidc_token") && mentions("oidc token") {
        "Authentication failed for Vercel AI Gateway. Use a Vercel AI Gateway key (vck_...).\nCreate or manage keys in Vercel AI Gateway console.\nIf this persists, regenerate the key and update it in Vercel AI Gateway provider settings in Zed.".to_string()
    } else if mentions("invalid api key") || mentions("invalid_api_key") {
        "Authentication failed for Vercel AI Gateway. Check that your Vercel AI Gateway key starts with vck_ and is active.".to_string()
    } else {
        message.to_string()
    }
}
340
/// Case-insensitive membership test over a tag list, ignoring surrounding
/// whitespace in each tag.
fn has_tag(tags: &[String], expected: &str) -> bool {
    for tag in tags {
        if tag.trim().eq_ignore_ascii_case(expected) {
            return true;
        }
    }
    false
}
345
346impl LanguageModel for VercelAiGatewayLanguageModel {
347 fn id(&self) -> LanguageModelId {
348 self.id.clone()
349 }
350
351 fn name(&self) -> LanguageModelName {
352 LanguageModelName::from(
353 self.model
354 .display_name
355 .clone()
356 .unwrap_or_else(|| self.model.name.clone()),
357 )
358 }
359
360 fn provider_id(&self) -> LanguageModelProviderId {
361 PROVIDER_ID
362 }
363
364 fn provider_name(&self) -> LanguageModelProviderName {
365 PROVIDER_NAME
366 }
367
368 fn supports_tools(&self) -> bool {
369 self.model.capabilities.tools
370 }
371
372 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
373 LanguageModelToolSchemaFormat::JsonSchemaSubset
374 }
375
376 fn supports_images(&self) -> bool {
377 self.model.capabilities.images
378 }
379
380 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
381 match choice {
382 LanguageModelToolChoice::Auto => self.model.capabilities.tools,
383 LanguageModelToolChoice::Any => self.model.capabilities.tools,
384 LanguageModelToolChoice::None => true,
385 }
386 }
387
388 fn supports_streaming_tools(&self) -> bool {
389 true
390 }
391
392 fn supports_split_token_display(&self) -> bool {
393 true
394 }
395
396 fn telemetry_id(&self) -> String {
397 format!("vercel_ai_gateway/{}", self.model.name)
398 }
399
400 fn max_token_count(&self) -> u64 {
401 self.model.max_tokens
402 }
403
404 fn max_output_tokens(&self) -> Option<u64> {
405 self.model.max_output_tokens
406 }
407
408 fn count_tokens(
409 &self,
410 request: LanguageModelRequest,
411 cx: &App,
412 ) -> BoxFuture<'static, Result<u64>> {
413 let max_token_count = self.max_token_count();
414 cx.background_spawn(async move {
415 let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
416 let model = if max_token_count >= 100_000 {
417 "gpt-4o"
418 } else {
419 "gpt-4"
420 };
421 tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
422 })
423 .boxed()
424 }
425
426 fn stream_completion(
427 &self,
428 request: LanguageModelRequest,
429 cx: &AsyncApp,
430 ) -> BoxFuture<
431 'static,
432 Result<
433 futures::stream::BoxStream<
434 'static,
435 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
436 >,
437 LanguageModelCompletionError,
438 >,
439 > {
440 let request = crate::provider::open_ai::into_open_ai(
441 request,
442 &self.model.name,
443 self.model.capabilities.parallel_tool_calls,
444 self.model.capabilities.prompt_cache_key,
445 self.max_output_tokens(),
446 None,
447 );
448 let completions = self.stream_open_ai(request, cx);
449 async move {
450 let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
451 Ok(mapper.map_stream(completions.await?).boxed())
452 }
453 .boxed()
454 }
455}
456
/// Top-level payload of the gateway's `/models` endpoint.
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ApiModel>,
}
461
/// One model entry from the gateway's `/models` response.
#[derive(Deserialize)]
struct ApiModel {
    // Canonical model id (e.g. "openai/gpt-5.3-codex" style).
    id: String,
    // Optional human-readable name.
    name: Option<String>,
    context_window: Option<u64>,
    max_tokens: Option<u64>,
    // Model kind; entries whose type is present and not "language" are skipped.
    #[serde(default)]
    r#type: Option<String>,
    // Parameter names the model accepts (e.g. "tools"); used for capability detection.
    #[serde(default)]
    supported_parameters: Vec<String>,
    // Free-form tags, also consulted for capability detection.
    #[serde(default)]
    tags: Vec<String>,
    architecture: Option<ApiModelArchitecture>,
}
476
/// Architecture metadata; `input_modalities` containing "image" marks vision support.
#[derive(Deserialize)]
struct ApiModelArchitecture {
    #[serde(default)]
    input_modalities: Vec<String>,
}
482
483async fn list_models(
484 client: &dyn HttpClient,
485 api_url: &str,
486 api_key: Option<&str>,
487) -> Result<Vec<AvailableModel>, LanguageModelCompletionError> {
488 let uri = format!("{api_url}/models?include_mappings=true");
489 let mut request_builder = HttpRequest::builder()
490 .method(Method::GET)
491 .uri(uri)
492 .header("Accept", "application/json");
493 if let Some(api_key) = api_key {
494 request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
495 }
496 let request = request_builder
497 .body(AsyncBody::default())
498 .map_err(|error| LanguageModelCompletionError::BuildRequestBody {
499 provider: PROVIDER_NAME,
500 error,
501 })?;
502 let mut response =
503 client
504 .send(request)
505 .await
506 .map_err(|error| LanguageModelCompletionError::HttpSend {
507 provider: PROVIDER_NAME,
508 error,
509 })?;
510
511 let mut body = String::new();
512 response
513 .body_mut()
514 .read_to_string(&mut body)
515 .await
516 .map_err(|error| LanguageModelCompletionError::ApiReadResponseError {
517 provider: PROVIDER_NAME,
518 error,
519 })?;
520
521 if !response.status().is_success() {
522 return Err(LanguageModelCompletionError::from_http_status(
523 PROVIDER_NAME,
524 response.status(),
525 extract_error_message(&body),
526 None,
527 ));
528 }
529
530 let response: ModelsResponse = serde_json::from_str(&body).map_err(|error| {
531 LanguageModelCompletionError::DeserializeResponse {
532 provider: PROVIDER_NAME,
533 error,
534 }
535 })?;
536
537 let mut models = Vec::new();
538 for model in response.data {
539 if let Some(model_type) = model.r#type.as_deref()
540 && model_type != "language"
541 {
542 continue;
543 }
544 let supports_tools = model
545 .supported_parameters
546 .iter()
547 .any(|parameter| parameter == "tools")
548 || has_tag(&model.tags, "tool-use")
549 || has_tag(&model.tags, "tools");
550 let supports_images = model.architecture.is_some_and(|architecture| {
551 architecture
552 .input_modalities
553 .iter()
554 .any(|modality| modality == "image")
555 }) || has_tag(&model.tags, "vision")
556 || has_tag(&model.tags, "image-input");
557 let parallel_tool_calls = model
558 .supported_parameters
559 .iter()
560 .any(|parameter| parameter == "parallel_tool_calls");
561 let prompt_cache_key = model
562 .supported_parameters
563 .iter()
564 .any(|parameter| parameter == "prompt_cache_key" || parameter == "cache_control");
565 models.push(AvailableModel {
566 name: model.id.clone(),
567 display_name: model.name.or(Some(model.id)),
568 max_tokens: model.context_window.or(model.max_tokens).unwrap_or(128_000),
569 max_output_tokens: model.max_tokens,
570 max_completion_tokens: None,
571 capabilities: ModelCapabilities {
572 tools: supports_tools,
573 images: supports_images,
574 parallel_tool_calls,
575 prompt_cache_key,
576 chat_completions: true,
577 },
578 });
579 }
580
581 Ok(models)
582}
583
/// Settings-panel view for entering/resetting the Vercel AI Gateway API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Some while credentials are still loading; cleared when the initial
    // authenticate attempt finishes.
    load_credentials_task: Option<Task<()>>,
}
589
590impl ConfigurationView {
591 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
592 let api_key_editor =
593 cx.new(|cx| InputField::new(window, cx, "vck_000000000000000000000000000"));
594
595 cx.observe(&state, |_, _, cx| cx.notify()).detach();
596
597 let load_credentials_task = Some(cx.spawn_in(window, {
598 let state = state.clone();
599 async move |this, cx| {
600 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
601 let _ = task.await;
602 }
603 this.update(cx, |this, cx| {
604 this.load_credentials_task = None;
605 cx.notify();
606 })
607 .log_err();
608 }
609 }));
610
611 Self {
612 api_key_editor,
613 state,
614 load_credentials_task,
615 }
616 }
617
618 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
619 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
620 if api_key.is_empty() {
621 return;
622 }
623
624 self.api_key_editor
625 .update(cx, |editor, cx| editor.set_text("", window, cx));
626
627 let state = self.state.clone();
628 cx.spawn_in(window, async move |_, cx| {
629 state
630 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
631 .await
632 })
633 .detach_and_log_err(cx);
634 }
635
636 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
637 self.api_key_editor
638 .update(cx, |editor, cx| editor.set_text("", window, cx));
639
640 let state = self.state.clone();
641 cx.spawn_in(window, async move |_, cx| {
642 state
643 .update(cx, |state, cx| state.set_api_key(None, cx))
644 .await
645 })
646 .detach_and_log_err(cx);
647 }
648
649 fn should_render_editor(&self, cx: &Context<Self>) -> bool {
650 !self.state.read(cx).is_authenticated()
651 }
652}
653
impl Render for ConfigurationView {
    /// Renders one of three states: a loading label, the key-entry form, or
    /// the "configured" card with a reset button.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            // Mention the URL only when it differs from the default endpoint.
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            if api_url == API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Vercel AI Gateway, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create an API key in"))
                                .child(ButtonLink::new(
                                    "Vercel AI Gateway's console",
                                    "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys&title=Go+to+AI+Gateway",
                                )),
                        )
                        .child(ListBulletItem::new(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.",
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Reset is disabled when the key comes from the environment, since
            // resetting stored credentials wouldn't remove the env var.
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .into_any_element()
        }
    }
}