// open_ai.rs — OpenAI language model provider.

  1use anyhow::Result;
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use futures::{FutureExt, StreamExt, future::BoxFuture};
  5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  6use http_client::HttpClient;
  7use language_model::{
  8    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  9    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
 10    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 11    LanguageModelRequest, LanguageModelToolChoice, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME,
 12    RateLimiter, env_var,
 13};
 14use menu;
 15use open_ai::{
 16    OPEN_AI_API_URL, ResponseStreamEvent,
 17    responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response},
 18    stream_completion,
 19};
 20use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
 21use std::sync::{Arc, LazyLock};
 22use strum::IntoEnumIterator;
 23use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 24use ui_input::InputField;
 25use util::ResultExt;
 26
 27pub use open_ai::completion::{
 28    OpenAiEventMapper, OpenAiResponseEventMapper, collect_tiktoken_messages, count_open_ai_tokens,
 29    into_open_ai, into_open_ai_response,
 30};
 31
/// Provider identity, shared with the `language_model` crate so every
/// reference to this provider uses a single definition.
const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME;

/// Environment variable consulted for an API key; read once lazily.
const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 37
/// User-configurable settings for the OpenAI provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL for API requests; when empty, `OPEN_AI_API_URL` is used
    /// (see `OpenAiLanguageModelProvider::api_url`).
    pub api_url: String,
    /// Extra models declared in settings; merged over the built-in model
    /// list in `provided_models`.
    pub available_models: Vec<AvailableModel>,
}
 43
/// Language-model provider backed by OpenAI (or an OpenAI-compatible
/// endpoint configured via `OpenAiSettings::api_url`).
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared authentication state; also observed by `ConfigurationView`.
    state: Entity<State>,
}
 48
/// Authentication state for the provider: the API key (tracked per API URL)
/// and the credential store used to persist it.
pub struct State {
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
 53
impl State {
    /// Whether an API key is currently available for this provider.
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores a new API key — or deletes the stored one when `api_key` is
    /// `None` — for the currently configured API URL.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.store(
            api_url,
            api_key,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }

    /// Loads credentials for the current API URL if needed (see
    /// `ApiKeyState::load_if_needed` for the exact caching behavior).
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }
}
 82
 83impl OpenAiLanguageModelProvider {
 84    pub fn new(
 85        http_client: Arc<dyn HttpClient>,
 86        credentials_provider: Arc<dyn CredentialsProvider>,
 87        cx: &mut App,
 88    ) -> Self {
 89        let state = cx.new(|cx| {
 90            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 91                let credentials_provider = this.credentials_provider.clone();
 92                let api_url = Self::api_url(cx);
 93                this.api_key_state.handle_url_change(
 94                    api_url,
 95                    |this| &mut this.api_key_state,
 96                    credentials_provider,
 97                    cx,
 98                );
 99                cx.notify();
100            })
101            .detach();
102            State {
103                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
104                credentials_provider,
105            }
106        });
107
108        Self { http_client, state }
109    }
110
111    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
112        Arc::new(OpenAiLanguageModel {
113            id: LanguageModelId::from(model.id().to_string()),
114            model,
115            state: self.state.clone(),
116            http_client: self.http_client.clone(),
117            request_limiter: RateLimiter::new(4),
118        })
119    }
120
121    fn settings(cx: &App) -> &OpenAiSettings {
122        &crate::AllLanguageModelSettings::get_global(cx).openai
123    }
124
125    fn api_url(cx: &App) -> SharedString {
126        let api_url = &Self::settings(cx).api_url;
127        if api_url.is_empty() {
128            open_ai::OPEN_AI_API_URL.into()
129        } else {
130            SharedString::new(api_url.as_str())
131        }
132    }
133}
134
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the provider's state entity so callers can observe
    /// authentication changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
142
impl LanguageModelProvider for OpenAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconOrSvg {
        IconOrSvg::Icon(IconName::AiOpenAi)
    }

    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(open_ai::Model::default_fast()))
    }

    /// Built-in models merged with user-configured ones. A settings entry
    /// whose `name` matches a built-in model's id replaces it; the
    /// `BTreeMap` keeps the result sorted by model id.
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from open_ai::Model::iter(), skipping the
        // `Custom` placeholder variant (it is only meaningful when filled
        // in from settings below).
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
            models.insert(
                model.name.clone(),
                open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                    reasoning_effort: model.reasoning_effort,
                    supports_chat_completions: model.capabilities.chat_completions,
                },
            );
        }

        models
            .into_values()
            .map(|model| self.create_language_model(model))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    /// Deletes the stored API key (used by the reset affordance in the UI).
    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state
            .update(cx, |state, cx| state.set_api_key(None, cx))
    }
}
219
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    /// Shared provider state; consulted for the API key on every request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Limits concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
227
impl OpenAiLanguageModel {
    /// Issues a streaming request via `open_ai::stream_completion` (the
    /// chat-completions code path), gated by the rate limiter. Resolves to
    /// `NoApiKey` when no key is configured for the current API URL.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        // Snapshot the key and URL up front so the async block below owns
        // plain values and doesn't touch the entity again.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    /// Mirror of `stream_completion` for the Responses API
    /// (`open_ai::responses::stream_response`), used when the model does
    /// not support chat completions.
    fn stream_response(
        &self,
        request: ResponseRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponsesStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let provider = PROVIDER_NAME;
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_response(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
293
impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        true
    }

    /// Image support by model variant. Custom models are conservatively
    /// treated as text-only. The exhaustive match (no `_` arm) forces this
    /// list to be revisited when a new `open_ai::Model` variant is added.
    fn supports_images(&self) -> bool {
        use open_ai::Model;
        match &self.model {
            Model::FourOmniMini
            | Model::FourPointOneNano
            | Model::Five
            | Model::FiveCodex
            | Model::FiveMini
            | Model::FiveNano
            | Model::FivePointOne
            | Model::FivePointTwo
            | Model::FivePointTwoCodex
            | Model::FivePointThreeCodex
            | Model::FivePointFour
            | Model::FivePointFourPro
            | Model::O1
            | Model::O3 => true,
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::O3Mini
            | Model::Custom { .. } => false,
        }
    }

    /// Every tool-choice mode is accepted for every model.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    fn supports_streaming_tools(&self) -> bool {
        true
    }

    /// Models with a configured reasoning effort are treated as
    /// thinking-capable.
    fn supports_thinking(&self) -> bool {
        self.model.reasoning_effort().is_some()
    }

    fn supports_split_token_display(&self) -> bool {
        true
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Token counting runs on the background executor since it is CPU-bound.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let model = self.model.clone();
        cx.background_spawn(async move { count_open_ai_tokens(request, model) })
            .boxed()
    }

    /// Routes the request to the chat-completions path when the model
    /// supports it, otherwise to the Responses API path; each branch maps
    /// the provider-specific event stream into the common
    /// `LanguageModelCompletionEvent` stream.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        if self.model.supports_chat_completions() {
            let request = into_open_ai(
                request,
                self.model.id(),
                self.model.supports_parallel_tool_calls(),
                self.model.supports_prompt_cache_key(),
                self.max_output_tokens(),
                self.model.reasoning_effort(),
            );
            let completions = self.stream_completion(request, cx);
            async move {
                let mapper = OpenAiEventMapper::new();
                Ok(mapper.map_stream(completions.await?).boxed())
            }
            .boxed()
        } else {
            let request = into_open_ai_response(
                request,
                self.model.id(),
                self.model.supports_parallel_tool_calls(),
                self.model.supports_prompt_cache_key(),
                self.max_output_tokens(),
                self.model.reasoning_effort(),
            );
            let completions = self.stream_response(request, cx);
            async move {
                let mapper = OpenAiResponseEventMapper::new();
                Ok(mapper.map_stream(completions.await?).boxed())
            }
            .boxed()
        }
    }
}
429
/// Settings UI for entering or resetting the OpenAI API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    /// `Some` while stored credentials are being loaded; drives the
    /// "Loading credentials…" state in `render`.
    load_credentials_task: Option<Task<()>>,
}
435
436impl ConfigurationView {
437    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
438        let api_key_editor = cx.new(|cx| {
439            InputField::new(
440                window,
441                cx,
442                "sk-000000000000000000000000000000000000000000000000",
443            )
444        });
445
446        cx.observe(&state, |_, _, cx| {
447            cx.notify();
448        })
449        .detach();
450
451        let load_credentials_task = Some(cx.spawn_in(window, {
452            let state = state.clone();
453            async move |this, cx| {
454                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
455                    // We don't log an error, because "not signed in" is also an error.
456                    let _ = task.await;
457                }
458                this.update(cx, |this, cx| {
459                    this.load_credentials_task = None;
460                    cx.notify();
461                })
462                .log_err();
463            }
464        }));
465
466        Self {
467            api_key_editor,
468            state,
469            load_credentials_task,
470        }
471    }
472
473    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
474        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
475        if api_key.is_empty() {
476            return;
477        }
478
479        // url changes can cause the editor to be displayed again
480        self.api_key_editor
481            .update(cx, |editor, cx| editor.set_text("", window, cx));
482
483        let state = self.state.clone();
484        cx.spawn_in(window, async move |_, cx| {
485            state
486                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
487                .await
488        })
489        .detach_and_log_err(cx);
490    }
491
492    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
493        self.api_key_editor
494            .update(cx, |input, cx| input.set_text("", window, cx));
495
496        let state = self.state.clone();
497        cx.spawn_in(window, async move |_, cx| {
498            state
499                .update(cx, |state, cx| state.set_api_key(None, cx))
500                .await
501        })
502        .detach_and_log_err(cx);
503    }
504
505    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
506        !self.state.read(cx).is_authenticated()
507    }
508}
509
impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Describe where the key came from: the env var, the default URL,
        // or a custom (OpenAI-compatible) endpoint.
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            if api_url == OPEN_AI_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        // Either the key-entry form (unauthenticated) or the configured-key
        // card with a reset button (disabled when the key comes from the
        // environment, since unsetting it there is the only way to reset).
        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("OpenAI's console", "https://platform.openai.com/api-keys"))
                        )
                        .child(
                            ListBulletItem::new("Ensure your OpenAI account has credits")
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
                        ),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any_element()
        } else {
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .into_any_element()
        };

        // Footer pointing at the OpenAI-compatible providers docs; gets a
        // separating border only when shown under the key-entry form.
        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .end_icon(
                        Icon::new(IconName::ArrowUpRight)
                            .size(IconSize::Small)
                            .color(Color::Muted),
                    )
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        // Show only the loading placeholder until credentials are loaded.
        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}