open_ai.rs

  1use anyhow::{anyhow, Context as _, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use gpui::{
  6    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
  7};
  8use http_client::HttpClient;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
 11    LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 12    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 13};
 14use open_ai::{
 15    stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
 16};
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use settings::{Settings, SettingsStore};
 20use std::sync::Arc;
 21use strum::IntoEnumIterator;
 22use theme::ThemeSettings;
 23use ui::{prelude::*, Icon, IconName, Tooltip};
 24use util::ResultExt;
 25
 26use crate::AllLanguageModelSettings;
 27
/// Identifier wrapped in `LanguageModelProviderId` by the provider and its models.
const PROVIDER_ID: &str = "openai";
/// Display name wrapped in `LanguageModelProviderName`; also used in error messages.
const PROVIDER_NAME: &str = "OpenAI";
 30
 31#[derive(Default, Clone, Debug, PartialEq)]
 32pub struct OpenAiSettings {
 33    pub api_url: String,
 34    pub available_models: Vec<AvailableModel>,
 35    pub needs_setting_migration: bool,
 36}
 37
 38#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 39pub struct AvailableModel {
 40    pub name: String,
 41    pub display_name: Option<String>,
 42    pub max_tokens: usize,
 43    pub max_output_tokens: Option<u32>,
 44    pub max_completion_tokens: Option<u32>,
 45}
 46
 47pub struct OpenAiLanguageModelProvider {
 48    http_client: Arc<dyn HttpClient>,
 49    state: gpui::Entity<State>,
 50}
 51
/// Authentication state shared by the provider, its models, and the configuration view.
pub struct State {
    /// API key currently held in memory; `None` until one is loaded or entered.
    api_key: Option<String>,
    /// `true` when `api_key` was taken from the environment variable rather than
    /// from the credential store.
    api_key_from_env: bool,
    /// Keeps the settings-store observation registered in `OpenAiLanguageModelProvider::new`.
    _subscription: Subscription,
}
 57
/// Environment variable consulted for the API key before the credential store is read.
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
 59
 60impl State {
 61    fn is_authenticated(&self) -> bool {
 62        self.api_key.is_some()
 63    }
 64
 65    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 66        let settings = &AllLanguageModelSettings::get_global(cx).openai;
 67        let delete_credentials = cx.delete_credentials(&settings.api_url);
 68        cx.spawn(|this, mut cx| async move {
 69            delete_credentials.await.log_err();
 70            this.update(&mut cx, |this, cx| {
 71                this.api_key = None;
 72                this.api_key_from_env = false;
 73                cx.notify();
 74            })
 75        })
 76    }
 77
 78    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 79        let settings = &AllLanguageModelSettings::get_global(cx).openai;
 80        let write_credentials =
 81            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
 82
 83        cx.spawn(|this, mut cx| async move {
 84            write_credentials.await?;
 85            this.update(&mut cx, |this, cx| {
 86                this.api_key = Some(api_key);
 87                cx.notify();
 88            })
 89        })
 90    }
 91
 92    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 93        if self.is_authenticated() {
 94            return Task::ready(Ok(()));
 95        }
 96
 97        let api_url = AllLanguageModelSettings::get_global(cx)
 98            .openai
 99            .api_url
100            .clone();
101        cx.spawn(|this, mut cx| async move {
102            let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
103                (api_key, true)
104            } else {
105                let (_, api_key) = cx
106                    .update(|cx| cx.read_credentials(&api_url))?
107                    .await?
108                    .ok_or(AuthenticateError::CredentialsNotFound)?;
109                (
110                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
111                    false,
112                )
113            };
114            this.update(&mut cx, |this, cx| {
115                this.api_key = Some(api_key);
116                this.api_key_from_env = from_env;
117                cx.notify();
118            })?;
119
120            Ok(())
121        })
122    }
123}
124
125impl OpenAiLanguageModelProvider {
126    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
127        let state = cx.new(|cx| State {
128            api_key: None,
129            api_key_from_env: false,
130            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
131                cx.notify();
132            }),
133        });
134
135        Self { http_client, state }
136    }
137}
138
139impl LanguageModelProviderState for OpenAiLanguageModelProvider {
140    type ObservableEntity = State;
141
142    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
143        Some(self.state.clone())
144    }
145}
146
147impl LanguageModelProvider for OpenAiLanguageModelProvider {
148    fn id(&self) -> LanguageModelProviderId {
149        LanguageModelProviderId(PROVIDER_ID.into())
150    }
151
152    fn name(&self) -> LanguageModelProviderName {
153        LanguageModelProviderName(PROVIDER_NAME.into())
154    }
155
156    fn icon(&self) -> IconName {
157        IconName::AiOpenAi
158    }
159
160    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
161        let mut models = BTreeMap::default();
162
163        // Add base models from open_ai::Model::iter()
164        for model in open_ai::Model::iter() {
165            if !matches!(model, open_ai::Model::Custom { .. }) {
166                models.insert(model.id().to_string(), model);
167            }
168        }
169
170        // Override with available models from settings
171        for model in &AllLanguageModelSettings::get_global(cx)
172            .openai
173            .available_models
174        {
175            models.insert(
176                model.name.clone(),
177                open_ai::Model::Custom {
178                    name: model.name.clone(),
179                    display_name: model.display_name.clone(),
180                    max_tokens: model.max_tokens,
181                    max_output_tokens: model.max_output_tokens,
182                    max_completion_tokens: model.max_completion_tokens,
183                },
184            );
185        }
186
187        models
188            .into_values()
189            .map(|model| {
190                Arc::new(OpenAiLanguageModel {
191                    id: LanguageModelId::from(model.id().to_string()),
192                    model,
193                    state: self.state.clone(),
194                    http_client: self.http_client.clone(),
195                    request_limiter: RateLimiter::new(4),
196                }) as Arc<dyn LanguageModel>
197            })
198            .collect()
199    }
200
201    fn is_authenticated(&self, cx: &App) -> bool {
202        self.state.read(cx).is_authenticated()
203    }
204
205    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
206        self.state.update(cx, |state, cx| state.authenticate(cx))
207    }
208
209    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
210        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
211            .into()
212    }
213
214    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
215        self.state.update(cx, |state, cx| state.reset_api_key(cx))
216    }
217}
218
219pub struct OpenAiLanguageModel {
220    id: LanguageModelId,
221    model: open_ai::Model,
222    state: gpui::Entity<State>,
223    http_client: Arc<dyn HttpClient>,
224    request_limiter: RateLimiter,
225}
226
227impl OpenAiLanguageModel {
228    fn stream_completion(
229        &self,
230        request: open_ai::Request,
231        cx: &AsyncApp,
232    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
233    {
234        let http_client = self.http_client.clone();
235        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
236            let settings = &AllLanguageModelSettings::get_global(cx).openai;
237            (state.api_key.clone(), settings.api_url.clone())
238        }) else {
239            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
240        };
241
242        let future = self.request_limiter.stream(async move {
243            let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
244            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
245            let response = request.await?;
246            Ok(response)
247        });
248
249        async move { Ok(future.await?.boxed()) }.boxed()
250    }
251}
252
253impl LanguageModel for OpenAiLanguageModel {
254    fn id(&self) -> LanguageModelId {
255        self.id.clone()
256    }
257
258    fn name(&self) -> LanguageModelName {
259        LanguageModelName::from(self.model.display_name().to_string())
260    }
261
262    fn provider_id(&self) -> LanguageModelProviderId {
263        LanguageModelProviderId(PROVIDER_ID.into())
264    }
265
266    fn provider_name(&self) -> LanguageModelProviderName {
267        LanguageModelProviderName(PROVIDER_NAME.into())
268    }
269
270    fn telemetry_id(&self) -> String {
271        format!("openai/{}", self.model.id())
272    }
273
274    fn max_token_count(&self) -> usize {
275        self.model.max_token_count()
276    }
277
278    fn max_output_tokens(&self) -> Option<u32> {
279        self.model.max_output_tokens()
280    }
281
282    fn count_tokens(
283        &self,
284        request: LanguageModelRequest,
285        cx: &App,
286    ) -> BoxFuture<'static, Result<usize>> {
287        count_open_ai_tokens(request, self.model.clone(), cx)
288    }
289
290    fn stream_completion(
291        &self,
292        request: LanguageModelRequest,
293        cx: &AsyncApp,
294    ) -> BoxFuture<
295        'static,
296        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
297    > {
298        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
299        let completions = self.stream_completion(request, cx);
300        async move {
301            Ok(open_ai::extract_text_from_events(completions.await?)
302                .map(|result| result.map(LanguageModelCompletionEvent::Text))
303                .boxed())
304        }
305        .boxed()
306    }
307
308    fn use_any_tool(
309        &self,
310        request: LanguageModelRequest,
311        tool_name: String,
312        tool_description: String,
313        schema: serde_json::Value,
314        cx: &AsyncApp,
315    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
316        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
317        request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
318            function: FunctionDefinition {
319                name: tool_name.clone(),
320                description: None,
321                parameters: None,
322            },
323        }));
324        request.tools = vec![ToolDefinition::Function {
325            function: FunctionDefinition {
326                name: tool_name.clone(),
327                description: Some(tool_description),
328                parameters: Some(schema),
329            },
330        }];
331
332        let response = self.stream_completion(request, cx);
333        self.request_limiter
334            .run(async move {
335                let response = response.await?;
336                Ok(
337                    open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
338                        .await?
339                        .boxed(),
340                )
341            })
342            .boxed()
343    }
344}
345
346pub fn count_open_ai_tokens(
347    request: LanguageModelRequest,
348    model: open_ai::Model,
349    cx: &App,
350) -> BoxFuture<'static, Result<usize>> {
351    cx.background_spawn(async move {
352        let messages = request
353            .messages
354            .into_iter()
355            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
356                role: match message.role {
357                    Role::User => "user".into(),
358                    Role::Assistant => "assistant".into(),
359                    Role::System => "system".into(),
360                },
361                content: Some(message.string_contents()),
362                name: None,
363                function_call: None,
364            })
365            .collect::<Vec<_>>();
366
367        match model {
368            open_ai::Model::Custom { .. }
369            | open_ai::Model::O1Mini
370            | open_ai::Model::O1
371            | open_ai::Model::O3Mini => tiktoken_rs::num_tokens_from_messages("gpt-4", &messages),
372            _ => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
373        }
374    })
375    .boxed()
376}
377
378struct ConfigurationView {
379    api_key_editor: Entity<Editor>,
380    state: gpui::Entity<State>,
381    load_credentials_task: Option<Task<()>>,
382}
383
384impl ConfigurationView {
385    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
386        let api_key_editor = cx.new(|cx| {
387            let mut editor = Editor::single_line(window, cx);
388            editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
389            editor
390        });
391
392        cx.observe(&state, |_, _, cx| {
393            cx.notify();
394        })
395        .detach();
396
397        let load_credentials_task = Some(cx.spawn_in(window, {
398            let state = state.clone();
399            |this, mut cx| async move {
400                if let Some(task) = state
401                    .update(&mut cx, |state, cx| state.authenticate(cx))
402                    .log_err()
403                {
404                    // We don't log an error, because "not signed in" is also an error.
405                    let _ = task.await;
406                }
407
408                this.update(&mut cx, |this, cx| {
409                    this.load_credentials_task = None;
410                    cx.notify();
411                })
412                .log_err();
413            }
414        }));
415
416        Self {
417            api_key_editor,
418            state,
419            load_credentials_task,
420        }
421    }
422
423    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
424        let api_key = self.api_key_editor.read(cx).text(cx);
425        if api_key.is_empty() {
426            return;
427        }
428
429        let state = self.state.clone();
430        cx.spawn_in(window, |_, mut cx| async move {
431            state
432                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
433                .await
434        })
435        .detach_and_log_err(cx);
436
437        cx.notify();
438    }
439
440    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
441        self.api_key_editor
442            .update(cx, |editor, cx| editor.set_text("", window, cx));
443
444        let state = self.state.clone();
445        cx.spawn_in(window, |_, mut cx| async move {
446            state
447                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
448                .await
449        })
450        .detach_and_log_err(cx);
451
452        cx.notify();
453    }
454
455    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
456        let settings = ThemeSettings::get_global(cx);
457        let text_style = TextStyle {
458            color: cx.theme().colors().text,
459            font_family: settings.ui_font.family.clone(),
460            font_features: settings.ui_font.features.clone(),
461            font_fallbacks: settings.ui_font.fallbacks.clone(),
462            font_size: rems(0.875).into(),
463            font_weight: settings.ui_font.weight,
464            font_style: FontStyle::Normal,
465            line_height: relative(1.3),
466            white_space: WhiteSpace::Normal,
467            ..Default::default()
468        };
469        EditorElement::new(
470            &self.api_key_editor,
471            EditorStyle {
472                background: cx.theme().colors().editor_background,
473                local_player: cx.theme().players().local(),
474                text: text_style,
475                ..Default::default()
476            },
477        )
478    }
479
480    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
481        !self.state.read(cx).is_authenticated()
482    }
483}
484
485impl Render for ConfigurationView {
486    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
487        const OPENAI_CONSOLE_URL: &str = "https://platform.openai.com/api-keys";
488        const INSTRUCTIONS: [&str; 4] = [
489            "To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:",
490            " - Create one by visiting:",
491            " - Ensure your OpenAI account has credits",
492            " - Paste your API key below and hit enter to start using the assistant",
493        ];
494
495        let env_var_set = self.state.read(cx).api_key_from_env;
496
497        if self.load_credentials_task.is_some() {
498            div().child(Label::new("Loading credentials...")).into_any()
499        } else if self.should_render_editor(cx) {
500            v_flex()
501                .size_full()
502                .on_action(cx.listener(Self::save_api_key))
503                .child(Label::new(INSTRUCTIONS[0]))
504                .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
505                    Button::new("openai_console", OPENAI_CONSOLE_URL)
506                        .style(ButtonStyle::Subtle)
507                        .icon(IconName::ArrowUpRight)
508                        .icon_size(IconSize::XSmall)
509                        .icon_color(Color::Muted)
510                        .on_click(move |_, _, cx| cx.open_url(OPENAI_CONSOLE_URL))
511                    )
512                )
513                .children(
514                    (2..INSTRUCTIONS.len()).map(|n|
515                        Label::new(INSTRUCTIONS[n])).collect::<Vec<_>>())
516                .child(
517                    h_flex()
518                        .w_full()
519                        .my_2()
520                        .px_2()
521                        .py_1()
522                        .bg(cx.theme().colors().editor_background)
523                        .border_1()
524                        .border_color(cx.theme().colors().border_variant)
525                        .rounded_md()
526                        .child(self.render_api_key_editor(cx)),
527                )
528                .child(
529                    Label::new(
530                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
531                    )
532                    .size(LabelSize::Small),
533                )
534                .child(
535                    Label::new(
536                        "Note that having a subscription for another service like GitHub Copilot won't work.".to_string(),
537                    )
538                    .size(LabelSize::Small),
539                )
540                .into_any()
541        } else {
542            h_flex()
543                .size_full()
544                .justify_between()
545                .child(
546                    h_flex()
547                        .gap_1()
548                        .child(Icon::new(IconName::Check).color(Color::Success))
549                        .child(Label::new(if env_var_set {
550                            format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
551                        } else {
552                            "API key configured.".to_string()
553                        })),
554                )
555                .child(
556                    Button::new("reset-key", "Reset key")
557                        .icon(Some(IconName::Trash))
558                        .icon_size(IconSize::Small)
559                        .icon_position(IconPosition::Start)
560                        .disabled(env_var_set)
561                        .when(env_var_set, |this| {
562                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
563                        })
564                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
565                )
566                .into_any()
567        }
568    }
569}