// vercel_ai_gateway.rs

  1use anyhow::Result;
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
  5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  7use language_model::{
  8    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  9    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
 10    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 11    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
 12    env_var,
 13};
 14use open_ai::ResponseStreamEvent;
 15use serde::Deserialize;
 16pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
 17pub use settings::VercelAiGatewayAvailableModel as AvailableModel;
 18use settings::{Settings, SettingsStore};
 19use std::sync::{Arc, LazyLock};
 20use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 21use ui_input::InputField;
 22use util::ResultExt;
 23
/// Stable identifier used to register this provider with the model registry.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel_ai_gateway");
/// Human-readable provider name shown in the UI and error messages.
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Vercel AI Gateway");

/// Default OpenAI-compatible gateway endpoint, used when settings leave `api_url` empty.
const API_URL: &str = "https://ai-gateway.vercel.sh/v1";
/// Environment variable that may supply the API key instead of stored credentials.
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_AI_GATEWAY_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 31
/// User-configurable settings for the Vercel AI Gateway provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelAiGatewaySettings {
    // Custom gateway base URL; empty means the default `API_URL` is used.
    pub api_url: String,
    // Extra models the user configured manually in settings.
    pub available_models: Vec<AvailableModel>,
}
 37
/// Language model provider backed by Vercel AI Gateway's OpenAI-compatible API.
pub struct VercelAiGatewayLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared provider state: API key plus the fetched model list.
    state: Entity<State>,
}
 42
/// Shared mutable provider state: credentials and the model list fetched
/// from the gateway.
pub struct State {
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
    http_client: Arc<dyn HttpClient>,
    // Models fetched from the gateway; cleared when unauthenticated.
    available_models: Vec<AvailableModel>,
    // Keeps the in-flight model fetch alive; replaced on re-authentication.
    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
}
 50
 51impl State {
 52    fn is_authenticated(&self) -> bool {
 53        self.api_key_state.has_key()
 54    }
 55
 56    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 57        let credentials_provider = self.credentials_provider.clone();
 58        let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
 59        self.api_key_state.store(
 60            api_url,
 61            api_key,
 62            |this| &mut this.api_key_state,
 63            credentials_provider,
 64            cx,
 65        )
 66    }
 67
 68    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 69        let credentials_provider = self.credentials_provider.clone();
 70        let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
 71        let task = self.api_key_state.load_if_needed(
 72            api_url,
 73            |this| &mut this.api_key_state,
 74            credentials_provider,
 75            cx,
 76        );
 77
 78        cx.spawn(async move |this, cx| {
 79            let result = task.await;
 80            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
 81                .ok();
 82            result
 83        })
 84    }
 85
 86    fn fetch_models(
 87        &mut self,
 88        cx: &mut Context<Self>,
 89    ) -> Task<Result<(), LanguageModelCompletionError>> {
 90        let http_client = self.http_client.clone();
 91        let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
 92        let api_key = self.api_key_state.key(&api_url);
 93        cx.spawn(async move |this, cx| {
 94            let models = list_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
 95            this.update(cx, |this, cx| {
 96                this.available_models = models;
 97                cx.notify();
 98            })
 99            .map_err(|e| LanguageModelCompletionError::Other(e))?;
100            Ok(())
101        })
102    }
103
104    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
105        if self.is_authenticated() {
106            let task = self.fetch_models(cx);
107            self.fetch_models_task.replace(task);
108        } else {
109            self.available_models = Vec::new();
110        }
111    }
112}
113
impl VercelAiGatewayLanguageModelProvider {
    /// Creates the provider and its shared [`State`], re-authenticating
    /// whenever this provider's settings change.
    pub fn new(
        http_client: Arc<dyn HttpClient>,
        credentials_provider: Arc<dyn CredentialsProvider>,
        cx: &mut App,
    ) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>({
                // Track the last-seen settings so we only react to real changes.
                let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
                move |this: &mut State, cx| {
                    let current_settings = VercelAiGatewayLanguageModelProvider::settings(cx);
                    if current_settings != &last_settings {
                        last_settings = current_settings.clone();
                        this.authenticate(cx).detach();
                        cx.notify();
                    }
                }
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
                credentials_provider,
                http_client: http_client.clone(),
                available_models: Vec::new(),
                fetch_models_task: None,
            }
        });

        Self { http_client, state }
    }

    /// Reads this provider's section of the global language-model settings.
    fn settings(cx: &App) -> &VercelAiGatewaySettings {
        &crate::AllLanguageModelSettings::get_global(cx).vercel_ai_gateway
    }

    /// Resolves the effective API URL: the configured value, or the default
    /// gateway endpoint when settings leave it empty.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }

    /// Built-in fallback model, offered even before the remote list loads.
    fn default_available_model() -> AvailableModel {
        AvailableModel {
            name: "openai/gpt-5.3-codex".to_string(),
            display_name: Some("GPT 5.3 Codex".to_string()),
            max_tokens: 400_000,
            max_output_tokens: Some(128_000),
            max_completion_tokens: None,
            capabilities: ModelCapabilities::default(),
        }
    }

    /// Wraps an [`AvailableModel`] in a [`LanguageModel`] handle that shares
    /// this provider's state, HTTP client, and a 4-concurrent-request limiter.
    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
        Arc::new(VercelAiGatewayLanguageModel {
            id: LanguageModelId::from(model.name.clone()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}
179
impl LanguageModelProviderState for VercelAiGatewayLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the shared [`State`] so consumers can observe changes to
    /// authentication status and the fetched model list.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
187
impl LanguageModelProvider for VercelAiGatewayLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiVercel)
    }

    /// The built-in default model; available even before the remote list loads.
    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(Self::default_available_model()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        None
    }

    /// Merges models from three sources, deduplicated by name. Later inserts
    /// win: settings-configured models override fetched gateway models, which
    /// override the built-in default. The BTreeMap keeps the result sorted.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        let default_model = Self::default_available_model();
        models.insert(default_model.name.clone(), default_model);

        for model in self.state.read(cx).available_models.clone() {
            models.insert(model.name.clone(), model);
        }

        for model in &Self::settings(cx).available_models {
            models.insert(model.name.clone(), model.clone());
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    /// Builds the API-key configuration UI backed by the shared state.
    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    /// Clears the stored API key.
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}
252
/// A single Vercel AI Gateway model, speaking the OpenAI-compatible
/// chat-completions protocol.
pub struct VercelAiGatewayLanguageModel {
    id: LanguageModelId,
    model: AvailableModel,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Limits this model to 4 concurrent requests (set in `create_language_model`).
    request_limiter: RateLimiter,
}
260
impl VercelAiGatewayLanguageModel {
    /// Starts a streaming OpenAI-style completion against the gateway,
    /// rate-limited and authenticated with the stored API key.
    fn stream_open_ai(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();
        // NOTE(review): `Entity::read_with` on an `AsyncApp` typically returns a
        // `Result` in gpui — confirm this tuple destructuring compiles as written.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            // Fail fast without a key rather than sending an unauthenticated request.
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = open_ai::stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await.map_err(map_open_ai_error)?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
298
/// Converts a low-level OpenAI client error into a provider-tagged completion
/// error, preserving the HTTP status, a cleaned-up message, and any
/// `Retry-After` hint from the response headers.
fn map_open_ai_error(error: open_ai::RequestError) -> LanguageModelCompletionError {
    match error {
        open_ai::RequestError::HttpResponseError {
            status_code,
            body,
            headers,
            ..
        } => {
            // Only honor Retry-After when it parses as a whole number of seconds
            // (the HTTP-date form of the header is ignored).
            let retry_after = headers
                .get(http::header::RETRY_AFTER)
                .and_then(|value| value.to_str().ok()?.parse::<u64>().ok())
                .map(std::time::Duration::from_secs);

            LanguageModelCompletionError::from_http_status(
                PROVIDER_NAME,
                status_code,
                extract_error_message(&body),
                retry_after,
            )
        }
        open_ai::RequestError::Other(error) => LanguageModelCompletionError::Other(error),
    }
}
322
323fn extract_error_message(body: &str) -> String {
324    let json = match serde_json::from_str::<serde_json::Value>(body) {
325        Ok(json) => json,
326        Err(_) => return body.to_string(),
327    };
328
329    let message = json
330        .get("error")
331        .and_then(|value| {
332            value
333                .get("message")
334                .and_then(serde_json::Value::as_str)
335                .or_else(|| value.as_str())
336        })
337        .or_else(|| json.get("message").and_then(serde_json::Value::as_str))
338        .map(ToString::to_string)
339        .unwrap_or_else(|| body.to_string());
340
341    clean_error_message(&message)
342}
343
/// Rewrites known confusing gateway error messages into actionable guidance;
/// all other messages pass through unchanged.
fn clean_error_message(message: &str) -> String {
    let lowercase = message.to_lowercase();

    // Users sometimes hit OIDC-token errors when they supply the wrong kind of
    // credential; point them at a proper gateway key instead.
    let is_oidc_error =
        lowercase.contains("vercel_oidc_token") && lowercase.contains("oidc token");
    if is_oidc_error {
        return "Authentication failed for Vercel AI Gateway. Use a Vercel AI Gateway key (vck_...).\nCreate or manage keys in Vercel AI Gateway console.\nIf this persists, regenerate the key and update it in Vercel AI Gateway provider settings in Zed.".to_string();
    }

    let is_bad_key_error =
        lowercase.contains("invalid api key") || lowercase.contains("invalid_api_key");
    if is_bad_key_error {
        return "Authentication failed for Vercel AI Gateway. Check that your Vercel AI Gateway key starts with vck_ and is active.".to_string();
    }

    message.to_string()
}
357
/// Returns true when any entry in `tags` equals `expected`, ignoring
/// surrounding whitespace and ASCII case.
fn has_tag(tags: &[String], expected: &str) -> bool {
    for tag in tags {
        if tag.trim().eq_ignore_ascii_case(expected) {
            return true;
        }
    }
    false
}
362
impl LanguageModel for VercelAiGatewayLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    /// Display name, falling back to the raw model id when none is configured.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(
            self.model
                .display_name
                .clone()
                .unwrap_or_else(|| self.model.name.clone()),
        )
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.capabilities.tools
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn supports_images(&self) -> bool {
        self.model.capabilities.images
    }

    /// Tool choice requires tool support; `None` (no tools) is always allowed.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => self.model.capabilities.tools,
            LanguageModelToolChoice::Any => self.model.capabilities.tools,
            LanguageModelToolChoice::None => true,
        }
    }

    fn supports_streaming_tools(&self) -> bool {
        true
    }

    fn supports_split_token_display(&self) -> bool {
        true
    }

    fn telemetry_id(&self) -> String {
        format!("vercel_ai_gateway/{}", self.model.name)
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_tokens
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens
    }

    /// Approximate token counting via tiktoken. The gateway serves many model
    /// families, so this uses the gpt-4o tokenizer for large-context models
    /// and gpt-4 otherwise as a heuristic, not an exact count.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let max_token_count = self.max_token_count();
        cx.background_spawn(async move {
            let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
            let model = if max_token_count >= 100_000 {
                "gpt-4o"
            } else {
                "gpt-4"
            };
            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    /// Converts the request to the OpenAI wire format, streams it through
    /// `stream_open_ai`, and maps raw stream events to completion events.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = crate::provider::open_ai::into_open_ai(
            request,
            &self.model.name,
            self.model.capabilities.parallel_tool_calls,
            self.model.capabilities.prompt_cache_key,
            self.max_output_tokens(),
            None,
        );
        let completions = self.stream_open_ai(request, cx);
        async move {
            let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}
473
/// Response envelope for the gateway's `GET /models` endpoint.
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ApiModel>,
}
478
/// One model entry as reported by the gateway's models endpoint.
#[derive(Deserialize)]
struct ApiModel {
    // Model identifier, e.g. "openai/gpt-5.3-codex".
    id: String,
    // Optional human-readable display name.
    name: Option<String>,
    context_window: Option<u64>,
    max_tokens: Option<u64>,
    // Model category; `list_models` keeps only "language" models.
    #[serde(default)]
    r#type: Option<String>,
    // Parameter names the model accepts; used to infer capabilities.
    #[serde(default)]
    supported_parameters: Vec<String>,
    // Free-form tags; also consulted for capability inference.
    #[serde(default)]
    tags: Vec<String>,
    architecture: Option<ApiModelArchitecture>,
}
493
/// Architecture metadata; used by `list_models` to detect image-input support.
#[derive(Deserialize)]
struct ApiModelArchitecture {
    #[serde(default)]
    input_modalities: Vec<String>,
}
499
/// Fetches the model catalog from the gateway's `/models` endpoint and maps it
/// into [`AvailableModel`]s, inferring capabilities from supported parameters,
/// tags, and architecture metadata. The key is optional because the endpoint
/// may be queried before authentication completes.
async fn list_models(
    client: &dyn HttpClient,
    api_url: &str,
    api_key: Option<&str>,
) -> Result<Vec<AvailableModel>, LanguageModelCompletionError> {
    let uri = format!("{api_url}/models?include_mappings=true");
    let mut request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(uri)
        .header("Accept", "application/json");
    if let Some(api_key) = api_key {
        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
    }
    let request = request_builder
        .body(AsyncBody::default())
        .map_err(|error| LanguageModelCompletionError::BuildRequestBody {
            provider: PROVIDER_NAME,
            error,
        })?;
    let mut response =
        client
            .send(request)
            .await
            .map_err(|error| LanguageModelCompletionError::HttpSend {
                provider: PROVIDER_NAME,
                error,
            })?;

    // Read the body before checking the status so error responses can have
    // their message extracted and surfaced.
    let mut body = String::new();
    response
        .body_mut()
        .read_to_string(&mut body)
        .await
        .map_err(|error| LanguageModelCompletionError::ApiReadResponseError {
            provider: PROVIDER_NAME,
            error,
        })?;

    if !response.status().is_success() {
        return Err(LanguageModelCompletionError::from_http_status(
            PROVIDER_NAME,
            response.status(),
            extract_error_message(&body),
            None,
        ));
    }

    let response: ModelsResponse = serde_json::from_str(&body).map_err(|error| {
        LanguageModelCompletionError::DeserializeResponse {
            provider: PROVIDER_NAME,
            error,
        }
    })?;

    let mut models = Vec::new();
    for model in response.data {
        // Skip non-language models (e.g. embedding or image models); entries
        // with no type at all are kept.
        if let Some(model_type) = model.r#type.as_deref()
            && model_type != "language"
        {
            continue;
        }
        // Capability inference: check declared parameters first, then tags.
        let supports_tools = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "tools")
            || has_tag(&model.tags, "tool-use")
            || has_tag(&model.tags, "tools");
        let supports_images = model.architecture.is_some_and(|architecture| {
            architecture
                .input_modalities
                .iter()
                .any(|modality| modality == "image")
        }) || has_tag(&model.tags, "vision")
            || has_tag(&model.tags, "image-input");
        let parallel_tool_calls = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "parallel_tool_calls");
        let prompt_cache_key = model
            .supported_parameters
            .iter()
            .any(|parameter| parameter == "prompt_cache_key" || parameter == "cache_control");
        models.push(AvailableModel {
            name: model.id.clone(),
            display_name: model.name.or(Some(model.id)),
            // Fall back to max_tokens, then a 128k default, when the gateway
            // omits the context window.
            max_tokens: model.context_window.or(model.max_tokens).unwrap_or(128_000),
            max_output_tokens: model.max_tokens,
            max_completion_tokens: None,
            capabilities: ModelCapabilities {
                tools: supports_tools,
                images: supports_images,
                parallel_tool_calls,
                prompt_cache_key,
                chat_completions: true,
            },
        });
    }

    Ok(models)
}
600
/// UI state for the provider's API-key configuration panel.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Some while the initial credential load is in flight; drives the
    // "Loading credentials..." state in `render`.
    load_credentials_task: Option<Task<()>>,
}
606
607impl ConfigurationView {
608    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
609        let api_key_editor =
610            cx.new(|cx| InputField::new(window, cx, "vck_000000000000000000000000000"));
611
612        cx.observe(&state, |_, _, cx| cx.notify()).detach();
613
614        let load_credentials_task = Some(cx.spawn_in(window, {
615            let state = state.clone();
616            async move |this, cx| {
617                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
618                    let _ = task.await;
619                }
620                this.update(cx, |this, cx| {
621                    this.load_credentials_task = None;
622                    cx.notify();
623                })
624                .log_err();
625            }
626        }));
627
628        Self {
629            api_key_editor,
630            state,
631            load_credentials_task,
632        }
633    }
634
635    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
636        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
637        if api_key.is_empty() {
638            return;
639        }
640
641        self.api_key_editor
642            .update(cx, |editor, cx| editor.set_text("", window, cx));
643
644        let state = self.state.clone();
645        cx.spawn_in(window, async move |_, cx| {
646            state
647                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
648                .await
649        })
650        .detach_and_log_err(cx);
651    }
652
653    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
654        self.api_key_editor
655            .update(cx, |editor, cx| editor.set_text("", window, cx));
656
657        let state = self.state.clone();
658        cx.spawn_in(window, async move |_, cx| {
659            state
660                .update(cx, |state, cx| state.set_api_key(None, cx))
661                .await
662        })
663        .detach_and_log_err(cx);
664    }
665
666    fn should_render_editor(&self, cx: &Context<Self>) -> bool {
667        !self.state.read(cx).is_authenticated()
668    }
669}
670
impl Render for ConfigurationView {
    /// Renders one of three states: loading, the key-entry form, or the
    /// configured-key card (with reset disabled when the key came from the
    /// environment variable).
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            // Mention the URL only when a non-default gateway is configured.
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            if api_url == API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Vercel AI Gateway, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create an API key in"))
                                .child(ButtonLink::new(
                                    "Vercel AI Gateway's console",
                                    "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys&title=Go+to+AI+Gateway",
                                )),
                        )
                        .child(ListBulletItem::new(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.",
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .into_any_element()
        }
    }
}