open_ai.rs

  1use anyhow::{anyhow, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use gpui::{
  6    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
  7    WhiteSpace,
  8};
  9use http_client::HttpClient;
 10use open_ai::stream_completion;
 11use settings::{Settings, SettingsStore};
 12use std::{sync::Arc, time::Duration};
 13use strum::IntoEnumIterator;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18use crate::{
 19    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 20    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 21    LanguageModelProviderState, LanguageModelRequest, Role,
 22};
 23
/// Identifier reported (as a `LanguageModelProviderId`) by the provider and by each of its models.
const PROVIDER_ID: &str = "openai";
/// Human-readable name reported (as a `LanguageModelProviderName`) by the provider and its models.
const PROVIDER_NAME: &str = "OpenAI";
 26
/// User-configurable settings for the OpenAI provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL of the OpenAI-compatible API. It is used both as the endpoint for
    /// completion requests and as the key under which the API key is stored in the
    /// credential store.
    pub api_url: String,
    /// Optional low-speed timeout forwarded to `open_ai::stream_completion`.
    pub low_speed_timeout: Option<Duration>,
    /// Extra models to offer. An entry whose id matches a built-in model replaces it.
    pub available_models: Vec<open_ai::Model>,
}
 33
/// Language-model provider backed by the OpenAI API.
pub struct OpenAiLanguageModelProvider {
    /// HTTP client shared with every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with the provider's models and its authentication prompt.
    state: gpui::Model<State>,
}
 38
/// Mutable provider state shared between the provider, its models and the authentication prompt.
struct State {
    /// API key in use. It is `None` until `authenticate` or the prompt's `save_api_key`
    /// sets it, and is cleared again by `reset_credentials`.
    api_key: Option<String>,
    /// Keeps the `SettingsStore` observer (which re-notifies this model on settings
    /// changes) alive for as long as the state exists. NOTE(review): this relies on
    /// gpui dropping the observation along with the handle.
    _subscription: Subscription,
}
 43
 44impl OpenAiLanguageModelProvider {
 45    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 46        let state = cx.new_model(|cx| State {
 47            api_key: None,
 48            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
 49                cx.notify();
 50            }),
 51        });
 52
 53        Self { http_client, state }
 54    }
 55}
 56
 57impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 58    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 59        Some(cx.observe(&self.state, |_, _, cx| {
 60            cx.notify();
 61        }))
 62    }
 63}
 64
 65impl LanguageModelProvider for OpenAiLanguageModelProvider {
 66    fn id(&self) -> LanguageModelProviderId {
 67        LanguageModelProviderId(PROVIDER_ID.into())
 68    }
 69
 70    fn name(&self) -> LanguageModelProviderName {
 71        LanguageModelProviderName(PROVIDER_NAME.into())
 72    }
 73
 74    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 75        let mut models = BTreeMap::default();
 76
 77        // Add base models from open_ai::Model::iter()
 78        for model in open_ai::Model::iter() {
 79            if !matches!(model, open_ai::Model::Custom { .. }) {
 80                models.insert(model.id().to_string(), model);
 81            }
 82        }
 83
 84        // Override with available models from settings
 85        for model in &AllLanguageModelSettings::get_global(cx)
 86            .openai
 87            .available_models
 88        {
 89            models.insert(model.id().to_string(), model.clone());
 90        }
 91
 92        models
 93            .into_values()
 94            .map(|model| {
 95                Arc::new(OpenAiLanguageModel {
 96                    id: LanguageModelId::from(model.id().to_string()),
 97                    model,
 98                    state: self.state.clone(),
 99                    http_client: self.http_client.clone(),
100                }) as Arc<dyn LanguageModel>
101            })
102            .collect()
103    }
104
105    fn is_authenticated(&self, cx: &AppContext) -> bool {
106        self.state.read(cx).api_key.is_some()
107    }
108
109    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
110        if self.is_authenticated(cx) {
111            Task::ready(Ok(()))
112        } else {
113            let api_url = AllLanguageModelSettings::get_global(cx)
114                .openai
115                .api_url
116                .clone();
117            let state = self.state.clone();
118            cx.spawn(|mut cx| async move {
119                let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
120                    api_key
121                } else {
122                    let (_, api_key) = cx
123                        .update(|cx| cx.read_credentials(&api_url))?
124                        .await?
125                        .ok_or_else(|| anyhow!("credentials not found"))?;
126                    String::from_utf8(api_key)?
127                };
128                state.update(&mut cx, |this, cx| {
129                    this.api_key = Some(api_key);
130                    cx.notify();
131                })
132            })
133        }
134    }
135
136    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
137        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
138            .into()
139    }
140
141    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
142        let settings = &AllLanguageModelSettings::get_global(cx).openai;
143        let delete_credentials = cx.delete_credentials(&settings.api_url);
144        let state = self.state.clone();
145        cx.spawn(|mut cx| async move {
146            delete_credentials.await.log_err();
147            state.update(&mut cx, |this, cx| {
148                this.api_key = None;
149                cx.notify();
150            })
151        })
152    }
153}
154
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    /// Identifier derived from `model.id()` when the provider builds the model.
    id: LanguageModelId,
    /// The underlying OpenAI model description.
    model: open_ai::Model,
    /// Authentication state shared with the provider (source of the API key).
    state: gpui::Model<State>,
    /// HTTP client used for completion requests.
    http_client: Arc<dyn HttpClient>,
}
161
162impl LanguageModel for OpenAiLanguageModel {
163    fn id(&self) -> LanguageModelId {
164        self.id.clone()
165    }
166
167    fn name(&self) -> LanguageModelName {
168        LanguageModelName::from(self.model.display_name().to_string())
169    }
170
171    fn provider_id(&self) -> LanguageModelProviderId {
172        LanguageModelProviderId(PROVIDER_ID.into())
173    }
174
175    fn provider_name(&self) -> LanguageModelProviderName {
176        LanguageModelProviderName(PROVIDER_NAME.into())
177    }
178
179    fn telemetry_id(&self) -> String {
180        format!("openai/{}", self.model.id())
181    }
182
183    fn max_token_count(&self) -> usize {
184        self.model.max_token_count()
185    }
186
187    fn count_tokens(
188        &self,
189        request: LanguageModelRequest,
190        cx: &AppContext,
191    ) -> BoxFuture<'static, Result<usize>> {
192        count_open_ai_tokens(request, self.model.clone(), cx)
193    }
194
195    fn stream_completion(
196        &self,
197        request: LanguageModelRequest,
198        cx: &AsyncAppContext,
199    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
200        let request = request.into_open_ai(self.model.id().into());
201
202        let http_client = self.http_client.clone();
203        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
204            let settings = &AllLanguageModelSettings::get_global(cx).openai;
205            (
206                state.api_key.clone(),
207                settings.api_url.clone(),
208                settings.low_speed_timeout,
209            )
210        }) else {
211            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
212        };
213
214        async move {
215            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
216            let request = stream_completion(
217                http_client.as_ref(),
218                &api_url,
219                &api_key,
220                request,
221                low_speed_timeout,
222            );
223            let response = request.await?;
224            Ok(open_ai::extract_text_from_events(response).boxed())
225        }
226        .boxed()
227    }
228}
229
230pub fn count_open_ai_tokens(
231    request: LanguageModelRequest,
232    model: open_ai::Model,
233    cx: &AppContext,
234) -> BoxFuture<'static, Result<usize>> {
235    cx.background_executor()
236        .spawn(async move {
237            let messages = request
238                .messages
239                .into_iter()
240                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
241                    role: match message.role {
242                        Role::User => "user".into(),
243                        Role::Assistant => "assistant".into(),
244                        Role::System => "system".into(),
245                    },
246                    content: Some(message.content),
247                    name: None,
248                    function_call: None,
249                })
250                .collect::<Vec<_>>();
251
252            if let open_ai::Model::Custom { .. } = model {
253                tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
254            } else {
255                tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
256            }
257        })
258        .boxed()
259}
260
/// View that asks the user to paste an OpenAI API key.
struct AuthenticationPrompt {
    /// Single-line editor holding the key being entered.
    api_key: View<Editor>,
    /// Provider state that receives the key once it is saved.
    state: gpui::Model<State>,
}
265
266impl AuthenticationPrompt {
267    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
268        Self {
269            api_key: cx.new_view(|cx| {
270                let mut editor = Editor::single_line(cx);
271                editor.set_placeholder_text(
272                    "sk-000000000000000000000000000000000000000000000000",
273                    cx,
274                );
275                editor
276            }),
277            state,
278        }
279    }
280
281    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
282        let api_key = self.api_key.read(cx).text(cx);
283        if api_key.is_empty() {
284            return;
285        }
286
287        let settings = &AllLanguageModelSettings::get_global(cx).openai;
288        let write_credentials =
289            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
290        let state = self.state.clone();
291        cx.spawn(|_, mut cx| async move {
292            write_credentials.await?;
293            state.update(&mut cx, |this, cx| {
294                this.api_key = Some(api_key);
295                cx.notify();
296            })
297        })
298        .detach_and_log_err(cx);
299    }
300
301    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
302        let settings = ThemeSettings::get_global(cx);
303        let text_style = TextStyle {
304            color: cx.theme().colors().text,
305            font_family: settings.ui_font.family.clone(),
306            font_features: settings.ui_font.features.clone(),
307            font_fallbacks: settings.ui_font.fallbacks.clone(),
308            font_size: rems(0.875).into(),
309            font_weight: settings.ui_font.weight,
310            font_style: FontStyle::Normal,
311            line_height: relative(1.3),
312            background_color: None,
313            underline: None,
314            strikethrough: None,
315            white_space: WhiteSpace::Normal,
316        };
317        EditorElement::new(
318            &self.api_key,
319            EditorStyle {
320                background: cx.theme().colors().editor_background,
321                local_player: cx.theme().players().local(),
322                text: text_style,
323                ..Default::default()
324            },
325        )
326    }
327}
328
329impl Render for AuthenticationPrompt {
330    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
331        const INSTRUCTIONS: [&str; 6] = [
332            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
333            " - You can create an API key at: platform.openai.com/api-keys",
334            " - Make sure your OpenAI account has credits",
335            " - Having a subscription for another service like GitHub Copilot won't work.",
336            "",
337            "Paste your OpenAI API key below and hit enter to use the assistant:",
338        ];
339
340        v_flex()
341            .p_4()
342            .size_full()
343            .on_action(cx.listener(Self::save_api_key))
344            .children(
345                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
346            )
347            .child(
348                h_flex()
349                    .w_full()
350                    .my_2()
351                    .px_2()
352                    .py_1()
353                    .bg(cx.theme().colors().editor_background)
354                    .rounded_md()
355                    .child(self.render_api_key_editor(cx)),
356            )
357            .child(
358                Label::new(
359                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
360                )
361                .size(LabelSize::Small),
362            )
363            .child(
364                h_flex()
365                    .gap_2()
366                    .child(Label::new("Click on").size(LabelSize::Small))
367                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
368                    .child(
369                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
370                    ),
371            )
372            .into_any()
373    }
374}