// open_ai.rs

  1use anyhow::Result;
  2use collections::BTreeMap;
  3use credentials_provider::CredentialsProvider;
  4use futures::{FutureExt, StreamExt, future::BoxFuture};
  5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  6use http_client::HttpClient;
  7use language_model::{
  8    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  9    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
 10    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 11    LanguageModelRequest, LanguageModelToolChoice, OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME,
 12    RateLimiter, env_var,
 13};
 14use menu;
 15use open_ai::{
 16    OPEN_AI_API_URL, ResponseStreamEvent,
 17    responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response},
 18    stream_completion,
 19};
 20use settings::{OpenAiAvailableModel as AvailableModel, Settings, SettingsStore};
 21use std::sync::{Arc, LazyLock};
 22use strum::IntoEnumIterator;
 23use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
 24use ui_input::InputField;
 25use util::ResultExt;
 26
 27pub use open_ai::completion::{
 28    OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
 29};
 30
// Provider identity shared with the rest of the app via the `language_model` crate.
const PROVIDER_ID: LanguageModelProviderId = OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = OPEN_AI_PROVIDER_NAME;

// Environment variable consulted for an API key when none is stored.
const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 36
/// User-configurable settings for the OpenAI provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    // Base API URL; an empty string means "use the default OpenAI endpoint".
    pub api_url: String,
    // Custom models declared in settings; these override built-in models with
    // the same id (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
 42
/// Language-model provider backed by the OpenAI API (or a compatible endpoint).
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    // Shared, observable state holding the API key; cloned into every model.
    state: Entity<State>,
}
 47
/// Credential state for the provider: tracks the API key for the currently
/// configured API URL and where to persist it.
pub struct State {
    api_key_state: ApiKeyState,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
 52
impl State {
    /// True once an API key is available for the current API URL.
    fn is_authenticated(&self) -> bool {
        self.api_key_state.has_key()
    }

    /// Stores the API key for the current API URL, or clears it when `api_key`
    /// is `None`. Persistence is delegated to the credentials provider.
    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.store(
            api_url,
            api_key,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }

    /// Loads the API key for the current API URL if it hasn't been loaded yet.
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        let credentials_provider = self.credentials_provider.clone();
        let api_url = OpenAiLanguageModelProvider::api_url(cx);
        self.api_key_state.load_if_needed(
            api_url,
            |this| &mut this.api_key_state,
            credentials_provider,
            cx,
        )
    }
}
 81
 82impl OpenAiLanguageModelProvider {
 83    pub fn new(
 84        http_client: Arc<dyn HttpClient>,
 85        credentials_provider: Arc<dyn CredentialsProvider>,
 86        cx: &mut App,
 87    ) -> Self {
 88        let state = cx.new(|cx| {
 89            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 90                let credentials_provider = this.credentials_provider.clone();
 91                let api_url = Self::api_url(cx);
 92                this.api_key_state.handle_url_change(
 93                    api_url,
 94                    |this| &mut this.api_key_state,
 95                    credentials_provider,
 96                    cx,
 97                );
 98                cx.notify();
 99            })
100            .detach();
101            State {
102                api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
103                credentials_provider,
104            }
105        });
106
107        Self { http_client, state }
108    }
109
110    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
111        Arc::new(OpenAiLanguageModel {
112            id: LanguageModelId::from(model.id().to_string()),
113            model,
114            state: self.state.clone(),
115            http_client: self.http_client.clone(),
116            request_limiter: RateLimiter::new(4),
117        })
118    }
119
120    fn settings(cx: &App) -> &OpenAiSettings {
121        &crate::AllLanguageModelSettings::get_global(cx).openai
122    }
123
124    fn api_url(cx: &App) -> SharedString {
125        let api_url = &Self::settings(cx).api_url;
126        if api_url.is_empty() {
127            open_ai::OPEN_AI_API_URL.into()
128        } else {
129            SharedString::new(api_url.as_str())
130        }
131    }
132}
133
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the credential state so the UI can observe authentication changes.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
141
142impl LanguageModelProvider for OpenAiLanguageModelProvider {
143    fn id(&self) -> LanguageModelProviderId {
144        PROVIDER_ID
145    }
146
147    fn name(&self) -> LanguageModelProviderName {
148        PROVIDER_NAME
149    }
150
151    fn icon(&self) -> IconOrSvg {
152        IconOrSvg::Icon(IconName::AiOpenAi)
153    }
154
155    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
156        Some(self.create_language_model(open_ai::Model::default()))
157    }
158
159    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
160        Some(self.create_language_model(open_ai::Model::default_fast()))
161    }
162
163    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
164        let mut models = BTreeMap::default();
165
166        // Add base models from open_ai::Model::iter()
167        for model in open_ai::Model::iter() {
168            if !matches!(model, open_ai::Model::Custom { .. }) {
169                models.insert(model.id().to_string(), model);
170            }
171        }
172
173        // Override with available models from settings
174        for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
175            models.insert(
176                model.name.clone(),
177                open_ai::Model::Custom {
178                    name: model.name.clone(),
179                    display_name: model.display_name.clone(),
180                    max_tokens: model.max_tokens,
181                    max_output_tokens: model.max_output_tokens,
182                    max_completion_tokens: model.max_completion_tokens,
183                    reasoning_effort: model.reasoning_effort,
184                    supports_chat_completions: model.capabilities.chat_completions,
185                },
186            );
187        }
188
189        models
190            .into_values()
191            .map(|model| self.create_language_model(model))
192            .collect()
193    }
194
195    fn is_authenticated(&self, cx: &App) -> bool {
196        self.state.read(cx).is_authenticated()
197    }
198
199    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
200        self.state.update(cx, |state, cx| state.authenticate(cx))
201    }
202
203    fn configuration_view(
204        &self,
205        _target_agent: language_model::ConfigurationViewTargetAgent,
206        window: &mut Window,
207        cx: &mut App,
208    ) -> AnyView {
209        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
210            .into()
211    }
212
213    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
214        self.state
215            .update(cx, |state, cx| state.set_api_key(None, cx))
216    }
217}
218
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    // Shared provider state; used to resolve the API key per request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps the number of concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
226
impl OpenAiLanguageModel {
    /// Streams a Chat Completions request, resolving the API key from state
    /// and routing through the rate limiter. Fails with `NoApiKey` when no
    /// key is configured for the current API URL.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        // Snapshot key + URL now so the async block is 'static.
        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let future = self.request_limiter.stream(async move {
            let provider = PROVIDER_NAME;
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    /// Same as `stream_completion`, but against the Responses API
    /// (used for models that don't support chat completions).
    fn stream_response(
        &self,
        request: ResponseRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponsesStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            (state.api_key_state.key(&api_url), api_url)
        });

        let provider = PROVIDER_NAME;
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_response(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
292
impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        true
    }

    /// Whether the model accepts image input. Custom models are conservatively
    /// treated as text-only. The match is deliberately exhaustive so adding a
    /// new `Model` variant forces a decision here.
    fn supports_images(&self) -> bool {
        use open_ai::Model;
        match &self.model {
            Model::FourOmniMini
            | Model::FourPointOneNano
            | Model::Five
            | Model::FiveCodex
            | Model::FiveMini
            | Model::FiveNano
            | Model::FivePointOne
            | Model::FivePointTwo
            | Model::FivePointTwoCodex
            | Model::FivePointThreeCodex
            | Model::FivePointFour
            | Model::FivePointFourPro
            | Model::O1
            | Model::O3 => true,
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::O3Mini
            | Model::Custom { .. } => false,
        }
    }

    // All tool-choice modes are supported.
    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => true,
            LanguageModelToolChoice::Any => true,
            LanguageModelToolChoice::None => true,
        }
    }

    fn supports_streaming_tools(&self) -> bool {
        true
    }

    // Thinking is tied to the model declaring a reasoning-effort setting.
    fn supports_thinking(&self) -> bool {
        self.model.reasoning_effort().is_some()
    }

    fn supports_split_token_display(&self) -> bool {
        true
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens()
    }

    /// Streams a completion, dispatching to the Chat Completions API when the
    /// model supports it and to the Responses API otherwise, then mapping raw
    /// events to `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        if self.model.supports_chat_completions() {
            let request = into_open_ai(
                request,
                self.model.id(),
                self.model.supports_parallel_tool_calls(),
                self.model.supports_prompt_cache_key(),
                self.max_output_tokens(),
                self.model.reasoning_effort(),
                false,
            );
            let completions = self.stream_completion(request, cx);
            async move {
                let mapper = OpenAiEventMapper::new();
                Ok(mapper.map_stream(completions.await?).boxed())
            }
            .boxed()
        } else {
            let request = into_open_ai_response(
                request,
                self.model.id(),
                self.model.supports_parallel_tool_calls(),
                self.model.supports_prompt_cache_key(),
                self.max_output_tokens(),
                self.model.reasoning_effort(),
            );
            let completions = self.stream_response(request, cx);
            async move {
                let mapper = OpenAiResponseEventMapper::new();
                Ok(mapper.map_stream(completions.await?).boxed())
            }
            .boxed()
        }
    }
}
419
/// Settings-UI view for entering/resetting the OpenAI API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Some(...) while stored credentials are being loaded; cleared when done.
    load_credentials_task: Option<Task<()>>,
}
425
426impl ConfigurationView {
427    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
428        let api_key_editor = cx.new(|cx| {
429            InputField::new(
430                window,
431                cx,
432                "sk-000000000000000000000000000000000000000000000000",
433            )
434        });
435
436        cx.observe(&state, |_, _, cx| {
437            cx.notify();
438        })
439        .detach();
440
441        let load_credentials_task = Some(cx.spawn_in(window, {
442            let state = state.clone();
443            async move |this, cx| {
444                if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
445                    // We don't log an error, because "not signed in" is also an error.
446                    let _ = task.await;
447                }
448                this.update(cx, |this, cx| {
449                    this.load_credentials_task = None;
450                    cx.notify();
451                })
452                .log_err();
453            }
454        }));
455
456        Self {
457            api_key_editor,
458            state,
459            load_credentials_task,
460        }
461    }
462
463    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
464        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
465        if api_key.is_empty() {
466            return;
467        }
468
469        // url changes can cause the editor to be displayed again
470        self.api_key_editor
471            .update(cx, |editor, cx| editor.set_text("", window, cx));
472
473        let state = self.state.clone();
474        cx.spawn_in(window, async move |_, cx| {
475            state
476                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
477                .await
478        })
479        .detach_and_log_err(cx);
480    }
481
482    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
483        self.api_key_editor
484            .update(cx, |input, cx| input.set_text("", window, cx));
485
486        let state = self.state.clone();
487        cx.spawn_in(window, async move |_, cx| {
488            state
489                .update(cx, |state, cx| state.set_api_key(None, cx))
490                .await
491        })
492        .detach_and_log_err(cx);
493    }
494
495    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
496        !self.state.read(cx).is_authenticated()
497    }
498}
499
impl Render for ConfigurationView {
    /// Renders either the API-key entry form (when unauthenticated) or a
    /// "configured" card, plus a note about OpenAI-compatible endpoints.
    /// Shows a loading placeholder while credentials are still being fetched.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
        // Label for the configured card; mention the custom URL when one is set.
        let configured_card_label = if env_var_set {
            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
        } else {
            let api_url = OpenAiLanguageModelProvider::api_url(cx);
            if api_url == OPEN_AI_API_URL {
                "API key configured".to_string()
            } else {
                format!("API key configured for {}", api_url)
            }
        };

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(
                            ListBulletItem::new("")
                                .child(Label::new("Create one by visiting"))
                                .child(ButtonLink::new("OpenAI's console", "https://platform.openai.com/api-keys"))
                        )
                        .child(
                            ListBulletItem::new("Ensure your OpenAI account has credits")
                        )
                        .child(
                            ListBulletItem::new("Paste your API key below and hit enter to start using the agent")
                        ),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(format!(
                        "You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
                    ))
                    .size(LabelSize::Small)
                    .color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any_element()
        } else {
            // Key already configured: offer reset, unless the key came from the
            // environment (in which case the button is disabled with a tooltip).
            ConfiguredApiCard::new(configured_card_label)
                .disabled(env_var_set)
                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
                .when(env_var_set, |this| {
                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
                })
                .into_any_element()
        };

        // Footer pointing at OpenAI-compatible provider docs; gains a top
        // border when shown below the entry form.
        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .end_icon(
                        Icon::new(IconName::ArrowUpRight)
                            .size(IconSize::Small)
                            .color(Color::Muted),
                    )
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}