open_ai.rs

  1use anyhow::{anyhow, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use gpui::{
  6    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
  7    WhiteSpace,
  8};
  9use http_client::HttpClient;
 10use open_ai::stream_completion;
 11use settings::{Settings, SettingsStore};
 12use std::{future, sync::Arc, time::Duration};
 13use strum::IntoEnumIterator;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18use crate::{
 19    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 20    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 21    LanguageModelProviderState, LanguageModelRequest, Role,
 22};
 23
/// Identifier reported both by the provider (`LanguageModelProvider::id`) and by
/// each model it hands out (`LanguageModel::provider_id`).
const PROVIDER_ID: &str = "openai";
/// Display name reported both by the provider and by each of its models.
const PROVIDER_NAME: &str = "OpenAI";
 26
/// Settings for the OpenAI provider; this file reads them through
/// `AllLanguageModelSettings::get_global(cx).openai`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL passed to `open_ai::stream_completion`. It is also the key under
    /// which the API key is read from, written to, and deleted from the system
    /// credential store.
    pub api_url: String,
    /// Forwarded unchanged to `open_ai::stream_completion` as its low-speed timeout.
    pub low_speed_timeout: Option<Duration>,
    /// Additional models from the user's settings. An entry whose id matches a
    /// built-in model replaces that model in `provided_models`.
    pub available_models: Vec<open_ai::Model>,
}
 33
/// Language-model provider exposing OpenAI chat models.
pub struct OpenAiLanguageModelProvider {
    /// HTTP client cloned into every `OpenAiLanguageModel` this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state, shared by handle with the models and the auth prompt.
    state: gpui::Model<State>,
}
 38
/// Mutable state shared by the provider, its models, and the authentication prompt.
struct State {
    /// API key in use. It is `None` until `authenticate` or the prompt stores one,
    /// and it is reset to `None` by `reset_credentials`.
    api_key: Option<String>,
    /// Keeps the `SettingsStore` observer registered in
    /// `OpenAiLanguageModelProvider::new` alive. That observer calls `cx.notify()`
    /// on this model whenever global settings change.
    _subscription: Subscription,
}
 43
 44impl OpenAiLanguageModelProvider {
 45    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 46        let state = cx.new_model(|cx| State {
 47            api_key: None,
 48            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
 49                cx.notify();
 50            }),
 51        });
 52
 53        Self { http_client, state }
 54    }
 55}
 56
 57impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 58    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 59        Some(cx.observe(&self.state, |_, _, cx| {
 60            cx.notify();
 61        }))
 62    }
 63}
 64
 65impl LanguageModelProvider for OpenAiLanguageModelProvider {
 66    fn id(&self) -> LanguageModelProviderId {
 67        LanguageModelProviderId(PROVIDER_ID.into())
 68    }
 69
 70    fn name(&self) -> LanguageModelProviderName {
 71        LanguageModelProviderName(PROVIDER_NAME.into())
 72    }
 73
 74    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 75        let mut models = BTreeMap::default();
 76
 77        // Add base models from open_ai::Model::iter()
 78        for model in open_ai::Model::iter() {
 79            if !matches!(model, open_ai::Model::Custom { .. }) {
 80                models.insert(model.id().to_string(), model);
 81            }
 82        }
 83
 84        // Override with available models from settings
 85        for model in &AllLanguageModelSettings::get_global(cx)
 86            .openai
 87            .available_models
 88        {
 89            models.insert(model.id().to_string(), model.clone());
 90        }
 91
 92        models
 93            .into_values()
 94            .map(|model| {
 95                Arc::new(OpenAiLanguageModel {
 96                    id: LanguageModelId::from(model.id().to_string()),
 97                    model,
 98                    state: self.state.clone(),
 99                    http_client: self.http_client.clone(),
100                }) as Arc<dyn LanguageModel>
101            })
102            .collect()
103    }
104
105    fn is_authenticated(&self, cx: &AppContext) -> bool {
106        self.state.read(cx).api_key.is_some()
107    }
108
109    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
110        if self.is_authenticated(cx) {
111            Task::ready(Ok(()))
112        } else {
113            let api_url = AllLanguageModelSettings::get_global(cx)
114                .openai
115                .api_url
116                .clone();
117            let state = self.state.clone();
118            cx.spawn(|mut cx| async move {
119                let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
120                    api_key
121                } else {
122                    let (_, api_key) = cx
123                        .update(|cx| cx.read_credentials(&api_url))?
124                        .await?
125                        .ok_or_else(|| anyhow!("credentials not found"))?;
126                    String::from_utf8(api_key)?
127                };
128                state.update(&mut cx, |this, cx| {
129                    this.api_key = Some(api_key);
130                    cx.notify();
131                })
132            })
133        }
134    }
135
136    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
137        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
138            .into()
139    }
140
141    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
142        let settings = &AllLanguageModelSettings::get_global(cx).openai;
143        let delete_credentials = cx.delete_credentials(&settings.api_url);
144        let state = self.state.clone();
145        cx.spawn(|mut cx| async move {
146            delete_credentials.await.log_err();
147            state.update(&mut cx, |this, cx| {
148                this.api_key = None;
149                cx.notify();
150            })
151        })
152    }
153}
154
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    /// Identifier built from `model.id()` when the provider enumerates its models.
    id: LanguageModelId,
    /// Underlying model description, either built-in or from settings.
    model: open_ai::Model,
    /// Shared provider state; the API key is read from it when streaming.
    state: gpui::Model<State>,
    /// HTTP client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
}
161
162impl LanguageModel for OpenAiLanguageModel {
163    fn id(&self) -> LanguageModelId {
164        self.id.clone()
165    }
166
167    fn name(&self) -> LanguageModelName {
168        LanguageModelName::from(self.model.display_name().to_string())
169    }
170
171    fn provider_id(&self) -> LanguageModelProviderId {
172        LanguageModelProviderId(PROVIDER_ID.into())
173    }
174
175    fn provider_name(&self) -> LanguageModelProviderName {
176        LanguageModelProviderName(PROVIDER_NAME.into())
177    }
178
179    fn telemetry_id(&self) -> String {
180        format!("openai/{}", self.model.id())
181    }
182
183    fn max_token_count(&self) -> usize {
184        self.model.max_token_count()
185    }
186
187    fn count_tokens(
188        &self,
189        request: LanguageModelRequest,
190        cx: &AppContext,
191    ) -> BoxFuture<'static, Result<usize>> {
192        count_open_ai_tokens(request, self.model.clone(), cx)
193    }
194
195    fn stream_completion(
196        &self,
197        request: LanguageModelRequest,
198        cx: &AsyncAppContext,
199    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
200        let request = request.into_open_ai(self.model.id().into());
201
202        let http_client = self.http_client.clone();
203        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
204            let settings = &AllLanguageModelSettings::get_global(cx).openai;
205            (
206                state.api_key.clone(),
207                settings.api_url.clone(),
208                settings.low_speed_timeout,
209            )
210        }) else {
211            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
212        };
213
214        async move {
215            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
216            let request = stream_completion(
217                http_client.as_ref(),
218                &api_url,
219                &api_key,
220                request,
221                low_speed_timeout,
222            );
223            let response = request.await?;
224            Ok(open_ai::extract_text_from_events(response).boxed())
225        }
226        .boxed()
227    }
228
229    fn use_tool(
230        &self,
231        _request: LanguageModelRequest,
232        _name: String,
233        _description: String,
234        _schema: serde_json::Value,
235        _cx: &AsyncAppContext,
236    ) -> BoxFuture<'static, Result<serde_json::Value>> {
237        future::ready(Err(anyhow!("not implemented"))).boxed()
238    }
239}
240
241pub fn count_open_ai_tokens(
242    request: LanguageModelRequest,
243    model: open_ai::Model,
244    cx: &AppContext,
245) -> BoxFuture<'static, Result<usize>> {
246    cx.background_executor()
247        .spawn(async move {
248            let messages = request
249                .messages
250                .into_iter()
251                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
252                    role: match message.role {
253                        Role::User => "user".into(),
254                        Role::Assistant => "assistant".into(),
255                        Role::System => "system".into(),
256                    },
257                    content: Some(message.content),
258                    name: None,
259                    function_call: None,
260                })
261                .collect::<Vec<_>>();
262
263            if let open_ai::Model::Custom { .. } = model {
264                tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
265            } else {
266                tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
267            }
268        })
269        .boxed()
270}
271
/// View that asks the user for an OpenAI API key and saves it.
struct AuthenticationPrompt {
    /// Single-line editor the key is typed or pasted into.
    api_key: View<Editor>,
    /// Shared provider state that receives the saved key.
    state: gpui::Model<State>,
}
276
277impl AuthenticationPrompt {
278    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
279        Self {
280            api_key: cx.new_view(|cx| {
281                let mut editor = Editor::single_line(cx);
282                editor.set_placeholder_text(
283                    "sk-000000000000000000000000000000000000000000000000",
284                    cx,
285                );
286                editor
287            }),
288            state,
289        }
290    }
291
292    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
293        let api_key = self.api_key.read(cx).text(cx);
294        if api_key.is_empty() {
295            return;
296        }
297
298        let settings = &AllLanguageModelSettings::get_global(cx).openai;
299        let write_credentials =
300            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
301        let state = self.state.clone();
302        cx.spawn(|_, mut cx| async move {
303            write_credentials.await?;
304            state.update(&mut cx, |this, cx| {
305                this.api_key = Some(api_key);
306                cx.notify();
307            })
308        })
309        .detach_and_log_err(cx);
310    }
311
312    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
313        let settings = ThemeSettings::get_global(cx);
314        let text_style = TextStyle {
315            color: cx.theme().colors().text,
316            font_family: settings.ui_font.family.clone(),
317            font_features: settings.ui_font.features.clone(),
318            font_fallbacks: settings.ui_font.fallbacks.clone(),
319            font_size: rems(0.875).into(),
320            font_weight: settings.ui_font.weight,
321            font_style: FontStyle::Normal,
322            line_height: relative(1.3),
323            background_color: None,
324            underline: None,
325            strikethrough: None,
326            white_space: WhiteSpace::Normal,
327        };
328        EditorElement::new(
329            &self.api_key,
330            EditorStyle {
331                background: cx.theme().colors().editor_background,
332                local_player: cx.theme().players().local(),
333                text: text_style,
334                ..Default::default()
335            },
336        )
337    }
338}
339
340impl Render for AuthenticationPrompt {
341    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
342        const INSTRUCTIONS: [&str; 6] = [
343            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
344            " - You can create an API key at: platform.openai.com/api-keys",
345            " - Make sure your OpenAI account has credits",
346            " - Having a subscription for another service like GitHub Copilot won't work.",
347            "",
348            "Paste your OpenAI API key below and hit enter to use the assistant:",
349        ];
350
351        v_flex()
352            .p_4()
353            .size_full()
354            .on_action(cx.listener(Self::save_api_key))
355            .children(
356                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
357            )
358            .child(
359                h_flex()
360                    .w_full()
361                    .my_2()
362                    .px_2()
363                    .py_1()
364                    .bg(cx.theme().colors().editor_background)
365                    .rounded_md()
366                    .child(self.render_api_key_editor(cx)),
367            )
368            .child(
369                Label::new(
370                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
371                )
372                .size(LabelSize::Small),
373            )
374            .child(
375                h_flex()
376                    .gap_2()
377                    .child(Label::new("Click on").size(LabelSize::Small))
378                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
379                    .child(
380                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
381                    ),
382            )
383            .into_any()
384    }
385}