// vercel_ai_gateway.rs

  1use anyhow::Result;
  2use collections::BTreeMap;
  3use futures::{AsyncReadExt, FutureExt, StreamExt, future::BoxFuture};
  4use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  6use language_model::{
  7    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  8    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 10    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
 11    env_var,
 12};
 13use open_ai::ResponseStreamEvent;
 14use serde::Deserialize;
 15pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
 16pub use settings::VercelAiGatewayAvailableModel as AvailableModel;
 17use settings::{Settings, SettingsStore};
 18use std::sync::{Arc, LazyLock};
 19use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 20use ui_input::InputField;
 21use util::ResultExt;
 22
// Stable identifier and human-readable name used to register this provider.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("vercel_ai_gateway");
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("Vercel AI Gateway");

// Default OpenAI-compatible gateway endpoint; overridable via settings.
const API_URL: &str = "https://ai-gateway.vercel.sh/v1";
// Environment variable consulted for the API key when none is stored.
const API_KEY_ENV_VAR_NAME: &str = "VERCEL_AI_GATEWAY_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 30
/// User-facing settings for the Vercel AI Gateway provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct VercelAiGatewaySettings {
    // Empty string means "use the default API_URL".
    pub api_url: String,
    // Models declared in settings; these override fetched models by name.
    pub available_models: Vec<AvailableModel>,
}
 36
/// Language-model provider backed by Vercel AI Gateway's
/// OpenAI-compatible API.
pub struct VercelAiGatewayLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: Entity<State>,
}
 41
/// Observable provider state: the API key, and the model list fetched
/// from the gateway once authenticated.
pub struct State {
    api_key_state: ApiKeyState,
    http_client: Arc<dyn HttpClient>,
    // Models fetched from the gateway; cleared when unauthenticated.
    available_models: Vec<AvailableModel>,
    // Holds the in-flight fetch so replacing it cancels a stale fetch.
    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
}
 48
 49impl State {
 50    fn is_authenticated(&self) -> bool {
 51        self.api_key_state.has_key()
 52    }
 53
 54    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 55        let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
 56        self.api_key_state
 57            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
 58    }
 59
 60    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 61        let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
 62        let task = self
 63            .api_key_state
 64            .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
 65
 66        cx.spawn(async move |this, cx| {
 67            let result = task.await;
 68            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
 69                .ok();
 70            result
 71        })
 72    }
 73
 74    fn fetch_models(
 75        &mut self,
 76        cx: &mut Context<Self>,
 77    ) -> Task<Result<(), LanguageModelCompletionError>> {
 78        let http_client = self.http_client.clone();
 79        let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
 80        let api_key = self.api_key_state.key(&api_url);
 81        cx.spawn(async move |this, cx| {
 82            let models = list_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
 83            this.update(cx, |this, cx| {
 84                this.available_models = models;
 85                cx.notify();
 86            })
 87            .map_err(|e| LanguageModelCompletionError::Other(e))?;
 88            Ok(())
 89        })
 90    }
 91
 92    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 93        if self.is_authenticated() {
 94            let task = self.fetch_models(cx);
 95            self.fetch_models_task.replace(task);
 96        } else {
 97            self.available_models = Vec::new();
 98        }
 99    }
100}
101
impl VercelAiGatewayLanguageModelProvider {
    /// Creates the provider and its observable `State`, wiring a settings
    /// observer that re-authenticates whenever the provider's settings change.
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            cx.observe_global::<SettingsStore>({
                // Snapshot the current settings so the observer reacts only
                // to actual changes, not every global settings notification.
                let mut last_settings = VercelAiGatewayLanguageModelProvider::settings(cx).clone();
                move |this: &mut State, cx| {
                    let current_settings = VercelAiGatewayLanguageModelProvider::settings(cx);
                    if current_settings != &last_settings {
                        last_settings = current_settings.clone();
                        this.authenticate(cx).detach();
                        cx.notify();
                    }
                }
            })
            .detach();
            State {
                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
                http_client: http_client.clone(),
                available_models: Vec::new(),
                fetch_models_task: None,
            }
        });

        Self { http_client, state }
    }

    /// Returns this provider's section of the global language-model settings.
    fn settings(cx: &App) -> &VercelAiGatewaySettings {
        &crate::AllLanguageModelSettings::get_global(cx).vercel_ai_gateway
    }

    /// The configured API URL, falling back to the default gateway
    /// endpoint when the setting is empty.
    fn api_url(cx: &App) -> SharedString {
        let api_url = &Self::settings(cx).api_url;
        if api_url.is_empty() {
            API_URL.into()
        } else {
            SharedString::new(api_url.as_str())
        }
    }

    /// The built-in model offered even before the remote list is fetched.
    fn default_available_model() -> AvailableModel {
        AvailableModel {
            name: "openai/gpt-5.3-codex".to_string(),
            display_name: Some("GPT 5.3 Codex".to_string()),
            max_tokens: 400_000,
            max_output_tokens: Some(128_000),
            max_completion_tokens: None,
            capabilities: ModelCapabilities::default(),
        }
    }

    /// Wraps an `AvailableModel` in this provider's `LanguageModel` impl.
    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
        Arc::new(VercelAiGatewayLanguageModel {
            id: LanguageModelId::from(model.name.clone()),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            // Allow up to 4 concurrent requests to this provider.
            request_limiter: RateLimiter::new(4),
        })
    }
}
162
163impl LanguageModelProviderState for VercelAiGatewayLanguageModelProvider {
164    type ObservableEntity = State;
165
166    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
167        Some(self.state.clone())
168    }
169}
170
171impl LanguageModelProvider for VercelAiGatewayLanguageModelProvider {
172    fn id(&self) -> LanguageModelProviderId {
173        PROVIDER_ID
174    }
175
176    fn name(&self) -> LanguageModelProviderName {
177        PROVIDER_NAME
178    }
179
180    fn icon(&self) -> IconOrSvg {
181        IconOrSvg::Icon(IconName::AiVercel)
182    }
183
184    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
185        Some(self.create_language_model(Self::default_available_model()))
186    }
187
188    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
189        None
190    }
191
192    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
193        let mut models = BTreeMap::default();
194
195        let default_model = Self::default_available_model();
196        models.insert(default_model.name.clone(), default_model);
197
198        for model in self.state.read(cx).available_models.clone() {
199            models.insert(model.name.clone(), model);
200        }
201
202        for model in &Self::settings(cx).available_models {
203            models.insert(model.name.clone(), model.clone());
204        }
205
206        models
207            .into_values()
208            .map(|model| self.create_language_model(model))
209            .collect()
210    }
211
212    fn is_authenticated(&self, cx: &App) -> bool {
213        self.state.read(cx).is_authenticated()
214    }
215
216    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
217        self.state.update(cx, |state, cx| state.authenticate(cx))
218    }
219
220    fn configuration_view(
221        &self,
222        _target_agent: language_model::ConfigurationViewTargetAgent,
223        window: &mut Window,
224        cx: &mut App,
225    ) -> AnyView {
226        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
227            .into()
228    }
229
230    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
231        self.state
232            .update(cx, |state, cx| state.set_api_key(None, cx))
233    }
234}
235
/// A single model exposed through the Vercel AI Gateway provider.
pub struct VercelAiGatewayLanguageModel {
    id: LanguageModelId,
    model: AvailableModel,
    // Shared provider state, used to read the API key per request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
243
impl VercelAiGatewayLanguageModel {
    /// Issues a rate-limited streaming completion request against the
    /// gateway's OpenAI-compatible endpoint.
    ///
    /// Resolves to the raw `ResponseStreamEvent` stream, or a
    /// `LanguageModelCompletionError` — including `NoApiKey` when no key
    /// is configured for the current API URL.
    fn stream_open_ai(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();
        // NOTE(review): `Entity::read_with` on an `AsyncApp` usually returns
        // a `Result`; confirm this direct destructuring compiles against the
        // current gpui API, or handle the dropped-app case explicitly.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = open_ai::stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            // Map transport/HTTP failures to provider-tagged errors.
            let response = request.await.map_err(map_open_ai_error)?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
281
282fn map_open_ai_error(error: open_ai::RequestError) -> LanguageModelCompletionError {
283    match error {
284        open_ai::RequestError::HttpResponseError {
285            status_code,
286            body,
287            headers,
288            ..
289        } => {
290            let retry_after = headers
291                .get(http::header::RETRY_AFTER)
292                .and_then(|value| value.to_str().ok()?.parse::<u64>().ok())
293                .map(std::time::Duration::from_secs);
294
295            LanguageModelCompletionError::from_http_status(
296                PROVIDER_NAME,
297                status_code,
298                extract_error_message(&body),
299                retry_after,
300            )
301        }
302        open_ai::RequestError::Other(error) => LanguageModelCompletionError::Other(error),
303    }
304}
305
306fn extract_error_message(body: &str) -> String {
307    let json = match serde_json::from_str::<serde_json::Value>(body) {
308        Ok(json) => json,
309        Err(_) => return body.to_string(),
310    };
311
312    let message = json
313        .get("error")
314        .and_then(|value| {
315            value
316                .get("message")
317                .and_then(serde_json::Value::as_str)
318                .or_else(|| value.as_str())
319        })
320        .or_else(|| json.get("message").and_then(serde_json::Value::as_str))
321        .map(ToString::to_string)
322        .unwrap_or_else(|| body.to_string());
323
324    clean_error_message(&message)
325}
326
/// Rewrites well-known Vercel AI Gateway error messages into actionable
/// guidance; unrecognized messages pass through unchanged.
fn clean_error_message(message: &str) -> String {
    let lowered = message.to_lowercase();

    // An OIDC-token error means the user supplied a Vercel OIDC token
    // instead of a gateway API key (vck_...).
    let is_oidc_error = lowered.contains("vercel_oidc_token") && lowered.contains("oidc token");
    if is_oidc_error {
        return "Authentication failed for Vercel AI Gateway. Use a Vercel AI Gateway key (vck_...).\nCreate or manage keys in Vercel AI Gateway console.\nIf this persists, regenerate the key and update it in Vercel AI Gateway provider settings in Zed.".to_string();
    }

    let is_bad_key = lowered.contains("invalid api key") || lowered.contains("invalid_api_key");
    if is_bad_key {
        return "Authentication failed for Vercel AI Gateway. Check that your Vercel AI Gateway key starts with vck_ and is active.".to_string();
    }

    message.to_string()
}
340
/// Returns true when `expected` matches any entry of `tags`, ignoring
/// surrounding whitespace and ASCII case.
fn has_tag(tags: &[String], expected: &str) -> bool {
    for tag in tags {
        if tag.trim().eq_ignore_ascii_case(expected) {
            return true;
        }
    }
    false
}
345
346impl LanguageModel for VercelAiGatewayLanguageModel {
347    fn id(&self) -> LanguageModelId {
348        self.id.clone()
349    }
350
351    fn name(&self) -> LanguageModelName {
352        LanguageModelName::from(
353            self.model
354                .display_name
355                .clone()
356                .unwrap_or_else(|| self.model.name.clone()),
357        )
358    }
359
360    fn provider_id(&self) -> LanguageModelProviderId {
361        PROVIDER_ID
362    }
363
364    fn provider_name(&self) -> LanguageModelProviderName {
365        PROVIDER_NAME
366    }
367
368    fn supports_tools(&self) -> bool {
369        self.model.capabilities.tools
370    }
371
372    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
373        LanguageModelToolSchemaFormat::JsonSchemaSubset
374    }
375
376    fn supports_images(&self) -> bool {
377        self.model.capabilities.images
378    }
379
380    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
381        match choice {
382            LanguageModelToolChoice::Auto => self.model.capabilities.tools,
383            LanguageModelToolChoice::Any => self.model.capabilities.tools,
384            LanguageModelToolChoice::None => true,
385        }
386    }
387
388    fn supports_split_token_display(&self) -> bool {
389        true
390    }
391
392    fn telemetry_id(&self) -> String {
393        format!("vercel_ai_gateway/{}", self.model.name)
394    }
395
396    fn max_token_count(&self) -> u64 {
397        self.model.max_tokens
398    }
399
400    fn max_output_tokens(&self) -> Option<u64> {
401        self.model.max_output_tokens
402    }
403
404    fn count_tokens(
405        &self,
406        request: LanguageModelRequest,
407        cx: &App,
408    ) -> BoxFuture<'static, Result<u64>> {
409        let max_token_count = self.max_token_count();
410        cx.background_spawn(async move {
411            let messages = crate::provider::open_ai::collect_tiktoken_messages(request);
412            let model = if max_token_count >= 100_000 {
413                "gpt-4o"
414            } else {
415                "gpt-4"
416            };
417            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
418        })
419        .boxed()
420    }
421
422    fn stream_completion(
423        &self,
424        request: LanguageModelRequest,
425        cx: &AsyncApp,
426    ) -> BoxFuture<
427        'static,
428        Result<
429            futures::stream::BoxStream<
430                'static,
431                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
432            >,
433            LanguageModelCompletionError,
434        >,
435    > {
436        let request = crate::provider::open_ai::into_open_ai(
437            request,
438            &self.model.name,
439            self.model.capabilities.parallel_tool_calls,
440            self.model.capabilities.prompt_cache_key,
441            self.max_output_tokens(),
442            None,
443        );
444        let completions = self.stream_open_ai(request, cx);
445        async move {
446            let mapper = crate::provider::open_ai::OpenAiEventMapper::new();
447            Ok(mapper.map_stream(completions.await?).boxed())
448        }
449        .boxed()
450    }
451}
452
/// Response shape of `GET {api_url}/models`.
#[derive(Deserialize)]
struct ModelsResponse {
    data: Vec<ApiModel>,
}

/// One model entry as returned by the gateway's models endpoint.
#[derive(Deserialize)]
struct ApiModel {
    // Provider-prefixed model id, e.g. "openai/gpt-5.3-codex".
    id: String,
    // Optional human-readable display name.
    name: Option<String>,
    context_window: Option<u64>,
    max_tokens: Option<u64>,
    // Model type; entries typed as anything other than "language" are
    // filtered out by `list_models`.
    #[serde(default)]
    r#type: Option<String>,
    // Parameter names the model accepts (e.g. "tools").
    #[serde(default)]
    supported_parameters: Vec<String>,
    // Free-form capability tags (e.g. "vision", "tool-use").
    #[serde(default)]
    tags: Vec<String>,
    architecture: Option<ApiModelArchitecture>,
}

/// Architecture metadata used to detect image-input support.
#[derive(Deserialize)]
struct ApiModelArchitecture {
    #[serde(default)]
    input_modalities: Vec<String>,
}
478
479async fn list_models(
480    client: &dyn HttpClient,
481    api_url: &str,
482    api_key: Option<&str>,
483) -> Result<Vec<AvailableModel>, LanguageModelCompletionError> {
484    let uri = format!("{api_url}/models?include_mappings=true");
485    let mut request_builder = HttpRequest::builder()
486        .method(Method::GET)
487        .uri(uri)
488        .header("Accept", "application/json");
489    if let Some(api_key) = api_key {
490        request_builder = request_builder.header("Authorization", format!("Bearer {}", api_key));
491    }
492    let request = request_builder
493        .body(AsyncBody::default())
494        .map_err(|error| LanguageModelCompletionError::BuildRequestBody {
495            provider: PROVIDER_NAME,
496            error,
497        })?;
498    let mut response =
499        client
500            .send(request)
501            .await
502            .map_err(|error| LanguageModelCompletionError::HttpSend {
503                provider: PROVIDER_NAME,
504                error,
505            })?;
506
507    let mut body = String::new();
508    response
509        .body_mut()
510        .read_to_string(&mut body)
511        .await
512        .map_err(|error| LanguageModelCompletionError::ApiReadResponseError {
513            provider: PROVIDER_NAME,
514            error,
515        })?;
516
517    if !response.status().is_success() {
518        return Err(LanguageModelCompletionError::from_http_status(
519            PROVIDER_NAME,
520            response.status(),
521            extract_error_message(&body),
522            None,
523        ));
524    }
525
526    let response: ModelsResponse = serde_json::from_str(&body).map_err(|error| {
527        LanguageModelCompletionError::DeserializeResponse {
528            provider: PROVIDER_NAME,
529            error,
530        }
531    })?;
532
533    let mut models = Vec::new();
534    for model in response.data {
535        if let Some(model_type) = model.r#type.as_deref()
536            && model_type != "language"
537        {
538            continue;
539        }
540        let supports_tools = model
541            .supported_parameters
542            .iter()
543            .any(|parameter| parameter == "tools")
544            || has_tag(&model.tags, "tool-use")
545            || has_tag(&model.tags, "tools");
546        let supports_images = model.architecture.is_some_and(|architecture| {
547            architecture
548                .input_modalities
549                .iter()
550                .any(|modality| modality == "image")
551        }) || has_tag(&model.tags, "vision")
552            || has_tag(&model.tags, "image-input");
553        let parallel_tool_calls = model
554            .supported_parameters
555            .iter()
556            .any(|parameter| parameter == "parallel_tool_calls");
557        let prompt_cache_key = model
558            .supported_parameters
559            .iter()
560            .any(|parameter| parameter == "prompt_cache_key" || parameter == "cache_control");
561        models.push(AvailableModel {
562            name: model.id.clone(),
563            display_name: model.name.or(Some(model.id)),
564            max_tokens: model.context_window.or(model.max_tokens).unwrap_or(128_000),
565            max_output_tokens: model.max_tokens,
566            max_completion_tokens: None,
567            capabilities: ModelCapabilities {
568                tools: supports_tools,
569                images: supports_images,
570                parallel_tool_calls,
571                prompt_cache_key,
572                chat_completions: true,
573            },
574        });
575    }
576
577    Ok(models)
578}
579
/// Settings UI for entering or resetting the Vercel AI Gateway API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Some while the initial credential load is in flight; the render
    // path shows a loading state until it is cleared.
    load_credentials_task: Option<Task<()>>,
}
585
586impl ConfigurationView {
587    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
588        let api_key_editor =
589            cx.new(|cx| InputField::new(window, cx, "vck_000000000000000000000000000"));
590
591        cx.observe(&state, |_, _, cx| cx.notify()).detach();
592
593        let load_credentials_task = Some(cx.spawn_in(window, {
594            let state = state.clone();
595            async move |this, cx| {
596                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
597                    let _ = task.await;
598                }
599                this.update(cx, |this, cx| {
600                    this.load_credentials_task = None;
601                    cx.notify();
602                })
603                .log_err();
604            }
605        }));
606
607        Self {
608            api_key_editor,
609            state,
610            load_credentials_task,
611        }
612    }
613
614    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
615        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
616        if api_key.is_empty() {
617            return;
618        }
619
620        self.api_key_editor
621            .update(cx, |editor, cx| editor.set_text("", window, cx));
622
623        let state = self.state.clone();
624        cx.spawn_in(window, async move |_, cx| {
625            state
626                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
627                .await
628        })
629        .detach_and_log_err(cx);
630    }
631
632    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
633        self.api_key_editor
634            .update(cx, |editor, cx| editor.set_text("", window, cx));
635
636        let state = self.state.clone();
637        cx.spawn_in(window, async move |_, cx| {
638            state
639                .update(cx, |state, cx| state.set_api_key(None, cx))
640                .await
641        })
642        .detach_and_log_err(cx);
643    }
644
645    fn should_render_editor(&self, cx: &Context<Self>) -> bool {
646        !self.state.read(cx).is_authenticated()
647    }
648}
649
impl Render for ConfigurationView {
    /// Renders one of three states: loading, the API-key entry form, or
    /// the "configured" card with a reset affordance.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Keys supplied via the env var cannot be reset from the UI.
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = VercelAiGatewayLanguageModelProvider::api_url(cx);
            // Mention the URL only when it differs from the default endpoint.
            if api_url == API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                // Confirm (Enter) in the editor saves the key.
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(
                    "To use Zed's agent with Vercel AI Gateway, you need to add an API key. Follow these steps:",
                ))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create an API key in"))
                                .child(ButtonLink::new(
                                    "Vercel AI Gateway's console",
                                    "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys&title=Go+to+AI+Gateway",
                                )),
                        )
                        .child(ListBulletItem::new(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.",
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .into_any_element()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .into_any_element()
        }
    }
}