open_ai.rs

use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::{
    AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
    View, WhiteSpace,
};
use http_client::HttpClient;
use open_ai::{
    stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Icon, IconName, Tooltip};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const PROVIDER_ID: &str = "openai";
const PROVIDER_NAME: &str = "OpenAI";

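/// Settings for the OpenAI language model provider.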
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
    pub needs_setting_migration: bool,
}

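/// A model configured by the user via the `available_models` setting, in addition to the
/// built-in models the provider ships with.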
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: usize,
    pub max_output_tokens: Option<u32>,
    pub max_completion_tokens: Option<u32>,
}

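/// The OpenAI implementation of [`LanguageModelProvider`], which registers the available
/// models and hands out [`OpenAiLanguageModel`] instances.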
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

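/// Shared authentication state: the API key currently in use and whether it came from the
/// `OPENAI_API_KEY` environment variable rather than the credential store.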
pub struct State {
    api_key: Option<String>,
    api_key_from_env: bool,
    _subscription: Subscription,
}

const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

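    /// Deletes the stored credential for the configured API URL and clears the in-memory key.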
    fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).openai;
        let delete_credentials = cx.delete_credentials(&settings.api_url);
        cx.spawn(|this, mut cx| async move {
            delete_credentials.await.log_err();
            this.update(&mut cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

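    /// Persists the API key in the system credential store, keyed by the configured API URL.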
    fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).openai;
        let write_credentials =
            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());

        cx.spawn(|this, mut cx| async move {
            write_credentials.await?;
            this.update(&mut cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

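    /// Loads the API key, preferring the `OPENAI_API_KEY` environment variable and falling
    /// back to the credential store entry for the configured API URL.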
    fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = AllLanguageModelSettings::get_global(cx)
                .openai
                .api_url
                .clone();
            cx.spawn(|this, mut cx| async move {
                let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
                    (api_key, true)
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    (String::from_utf8(api_key)?, false)
                };
                this.update(&mut cx, |this, cx| {
                    this.api_key = Some(api_key);
                    this.api_key_from_env = from_env;
                    cx.notify();
                })
            })
        }
    }
}

impl OpenAiLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let state = cx.new_model(|cx| State {
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
                cx.notify();
            }),
        });

        Self { http_client, state }
    }
}

impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OpenAiLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAi
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models = BTreeMap::default();

        // Add base models from open_ai::Model::iter()
        for model in open_ai::Model::iter() {
            if !matches!(model, open_ai::Model::Custom { .. }) {
                models.insert(model.id().to_string(), model);
            }
        }

        // Override with available models from settings
        for model in &AllLanguageModelSettings::get_global(cx)
            .openai
            .available_models
        {
            models.insert(
                model.name.clone(),
                open_ai::Model::Custom {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    max_output_tokens: model.max_output_tokens,
                    max_completion_tokens: model.max_completion_tokens,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OpenAiLanguageModel {
                    id: LanguageModelId::from(model.id().to_string()),
                    model,
                    state: self.state.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

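/// A single OpenAI model exposed to the rest of the app through the [`LanguageModel`] trait.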
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    state: gpui::Model<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OpenAiLanguageModel {
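    /// Issues a streaming completion request against the OpenAI-compatible API, reading the
    /// API key and URL from the shared state and throttling requests through the rate limiter.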
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).openai;
            (
                state.api_key.clone(),
                settings.api_url.clone(),
                settings.low_speed_timeout,
            )
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for OpenAiLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u32> {
        self.model.max_output_tokens()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, self.model.clone(), cx)
    }

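    /// Converts the request into the OpenAI wire format and streams the response back as
    /// plain-text completion events.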
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<
        'static,
        Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
    > {
        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
        let completions = self.stream_completion(request, cx);
        async move {
            Ok(open_ai::extract_text_from_events(completions.await?)
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

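    /// Forces the model to call the given tool by registering it as the only tool definition
    /// and setting `tool_choice` to it, then streams back the tool's arguments as they arrive.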
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
        request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
            function: FunctionDefinition {
                name: tool_name.clone(),
                description: None,
                parameters: None,
            },
        }));
        request.tools = vec![ToolDefinition::Function {
            function: FunctionDefinition {
                name: tool_name.clone(),
                description: Some(tool_description),
                parameters: Some(schema),
            },
        }];

        let response = self.stream_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                Ok(
                    open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
                        .await?
                        .boxed(),
                )
            })
            .boxed()
    }
}

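/// Counts tokens with `tiktoken_rs` on the background executor. Custom and o1 models are
/// counted with the `gpt-4` tokenizer as a stand-in for their own model IDs.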
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: open_ai::Model,
    cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
    cx.background_executor()
        .spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.string_contents()),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            match model {
                open_ai::Model::Custom { .. }
                | open_ai::Model::O1Mini
                | open_ai::Model::O1Preview => {
                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
                }
                _ => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            }
        })
        .boxed()
}

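/// The configuration UI for the provider: prompts for an API key, or shows the current
/// authentication status with an option to reset the key.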
struct ConfigurationView {
    api_key_editor: View<Editor>,
    state: gpui::Model<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let api_key_editor = cx.new_view(|cx| {
            let mut editor = Editor::single_line(cx);
            editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
            editor
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }

                this.update(&mut cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key_editor.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let state = self.state.clone();
        cx.spawn(|_, mut cx| async move {
            state
                .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
        self.api_key_editor
            .update(cx, |editor, cx| editor.set_text("", cx));

        let state = self.state.clone();
        cx.spawn(|_, mut cx| async move {
            state
                .update(&mut cx, |state, cx| state.reset_api_key(cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_fallbacks: settings.ui_font.fallbacks.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
            truncate: None,
        };
        EditorElement::new(
            &self.api_key_editor,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }

    fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const OPENAI_CONSOLE_URL: &str = "https://platform.openai.com/api-keys";
        const INSTRUCTIONS: [&str; 4] = [
            "To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:",
            " - Create one by visiting:",
            " - Ensure your OpenAI account has credits",
            " - Paste your API key below and hit enter to start using the assistant",
        ];

        let env_var_set = self.state.read(cx).api_key_from_env;

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new(INSTRUCTIONS[0]))
                .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
                    Button::new("openai_console", OPENAI_CONSOLE_URL)
                        .style(ButtonStyle::Subtle)
                        .icon(IconName::ExternalLink)
                        .icon_size(IconSize::XSmall)
                        .icon_color(Color::Muted)
                        .on_click(move |_, cx| cx.open_url(OPENAI_CONSOLE_URL))
                    )
                )
                .children(
                    (2..INSTRUCTIONS.len()).map(|n|
                        Label::new(INSTRUCTIONS[n])).collect::<Vec<_>>())
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .rounded_md()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.".to_string(),
                    )
                    .size(LabelSize::Small),
                )
                .into_any()
        } else {
            h_flex()
                .size_full()
                .justify_between()
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset key")
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        .when(env_var_set, |this| {
                            this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable."), cx))
                        })
                        .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
                )
                .into_any()
        }
    }
}