open_ai.rs

  1use anyhow::{anyhow, Result};
  2use collections::BTreeMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use gpui::{
  6    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
  7    WhiteSpace,
  8};
  9use http_client::HttpClient;
 10use open_ai::{stream_completion, Request, RequestMessage};
 11use settings::{Settings, SettingsStore};
 12use std::{sync::Arc, time::Duration};
 13use strum::IntoEnumIterator;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18use crate::{
 19    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 20    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
 21    LanguageModelProviderState, LanguageModelRequest, Role,
 22};
 23
 /// Identifier reported through `LanguageModelProviderId` by this provider and its models.
const PROVIDER_ID: &str = "openai";
/// Name reported through `LanguageModelProviderName` by this provider and its models.
const PROVIDER_NAME: &str = "OpenAI";
 26
 27#[derive(Default, Clone, Debug, PartialEq)]
 28pub struct OpenAiSettings {
 29    pub api_url: String,
 30    pub low_speed_timeout: Option<Duration>,
 31    pub available_models: Vec<open_ai::Model>,
 32}
 33
 34pub struct OpenAiLanguageModelProvider {
 35    http_client: Arc<dyn HttpClient>,
 36    state: gpui::Model<State>,
 37}
 38
 39struct State {
 40    api_key: Option<String>,
 41    _subscription: Subscription,
 42}
 43
 44impl OpenAiLanguageModelProvider {
 45    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 46        let state = cx.new_model(|cx| State {
 47            api_key: None,
 48            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
 49                cx.notify();
 50            }),
 51        });
 52
 53        Self { http_client, state }
 54    }
 55}
 56
 57impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 58    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 59        Some(cx.observe(&self.state, |_, _, cx| {
 60            cx.notify();
 61        }))
 62    }
 63}
 64
 65impl LanguageModelProvider for OpenAiLanguageModelProvider {
 66    fn id(&self) -> LanguageModelProviderId {
 67        LanguageModelProviderId(PROVIDER_ID.into())
 68    }
 69
 70    fn name(&self) -> LanguageModelProviderName {
 71        LanguageModelProviderName(PROVIDER_NAME.into())
 72    }
 73
 74    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 75        let mut models = BTreeMap::default();
 76
 77        // Add base models from open_ai::Model::iter()
 78        for model in open_ai::Model::iter() {
 79            if !matches!(model, open_ai::Model::Custom { .. }) {
 80                models.insert(model.id().to_string(), model);
 81            }
 82        }
 83
 84        // Override with available models from settings
 85        for model in &AllLanguageModelSettings::get_global(cx)
 86            .openai
 87            .available_models
 88        {
 89            models.insert(model.id().to_string(), model.clone());
 90        }
 91
 92        models
 93            .into_values()
 94            .map(|model| {
 95                Arc::new(OpenAiLanguageModel {
 96                    id: LanguageModelId::from(model.id().to_string()),
 97                    model,
 98                    state: self.state.clone(),
 99                    http_client: self.http_client.clone(),
100                }) as Arc<dyn LanguageModel>
101            })
102            .collect()
103    }
104
105    fn is_authenticated(&self, cx: &AppContext) -> bool {
106        self.state.read(cx).api_key.is_some()
107    }
108
109    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
110        if self.is_authenticated(cx) {
111            Task::ready(Ok(()))
112        } else {
113            let api_url = AllLanguageModelSettings::get_global(cx)
114                .openai
115                .api_url
116                .clone();
117            let state = self.state.clone();
118            cx.spawn(|mut cx| async move {
119                let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
120                    api_key
121                } else {
122                    let (_, api_key) = cx
123                        .update(|cx| cx.read_credentials(&api_url))?
124                        .await?
125                        .ok_or_else(|| anyhow!("credentials not found"))?;
126                    String::from_utf8(api_key)?
127                };
128                state.update(&mut cx, |this, cx| {
129                    this.api_key = Some(api_key);
130                    cx.notify();
131                })
132            })
133        }
134    }
135
136    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
137        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
138            .into()
139    }
140
141    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
142        let settings = &AllLanguageModelSettings::get_global(cx).openai;
143        let delete_credentials = cx.delete_credentials(&settings.api_url);
144        let state = self.state.clone();
145        cx.spawn(|mut cx| async move {
146            delete_credentials.await.log_err();
147            state.update(&mut cx, |this, cx| {
148                this.api_key = None;
149                cx.notify();
150            })
151        })
152    }
153}
154
155pub struct OpenAiLanguageModel {
156    id: LanguageModelId,
157    model: open_ai::Model,
158    state: gpui::Model<State>,
159    http_client: Arc<dyn HttpClient>,
160}
161
162impl OpenAiLanguageModel {
163    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
164        Request {
165            model: self.model.clone(),
166            messages: request
167                .messages
168                .into_iter()
169                .map(|msg| match msg.role {
170                    Role::User => RequestMessage::User {
171                        content: msg.content,
172                    },
173                    Role::Assistant => RequestMessage::Assistant {
174                        content: Some(msg.content),
175                        tool_calls: Vec::new(),
176                    },
177                    Role::System => RequestMessage::System {
178                        content: msg.content,
179                    },
180                })
181                .collect(),
182            stream: true,
183            stop: request.stop,
184            temperature: request.temperature,
185            tools: Vec::new(),
186            tool_choice: None,
187        }
188    }
189}
190
191impl LanguageModel for OpenAiLanguageModel {
192    fn id(&self) -> LanguageModelId {
193        self.id.clone()
194    }
195
196    fn name(&self) -> LanguageModelName {
197        LanguageModelName::from(self.model.display_name().to_string())
198    }
199
200    fn provider_id(&self) -> LanguageModelProviderId {
201        LanguageModelProviderId(PROVIDER_ID.into())
202    }
203
204    fn provider_name(&self) -> LanguageModelProviderName {
205        LanguageModelProviderName(PROVIDER_NAME.into())
206    }
207
208    fn telemetry_id(&self) -> String {
209        format!("openai/{}", self.model.id())
210    }
211
212    fn max_token_count(&self) -> usize {
213        self.model.max_token_count()
214    }
215
216    fn count_tokens(
217        &self,
218        request: LanguageModelRequest,
219        cx: &AppContext,
220    ) -> BoxFuture<'static, Result<usize>> {
221        count_open_ai_tokens(request, self.model.clone(), cx)
222    }
223
224    fn stream_completion(
225        &self,
226        request: LanguageModelRequest,
227        cx: &AsyncAppContext,
228    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
229        let request = self.to_open_ai_request(request);
230
231        let http_client = self.http_client.clone();
232        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
233            let settings = &AllLanguageModelSettings::get_global(cx).openai;
234            (
235                state.api_key.clone(),
236                settings.api_url.clone(),
237                settings.low_speed_timeout,
238            )
239        }) else {
240            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
241        };
242
243        async move {
244            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
245            let request = stream_completion(
246                http_client.as_ref(),
247                &api_url,
248                &api_key,
249                request,
250                low_speed_timeout,
251            );
252            let response = request.await?;
253            let stream = response
254                .filter_map(|response| async move {
255                    match response {
256                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
257                        Err(error) => Some(Err(error)),
258                    }
259                })
260                .boxed();
261            Ok(stream)
262        }
263        .boxed()
264    }
265}
266
267pub fn count_open_ai_tokens(
268    request: LanguageModelRequest,
269    model: open_ai::Model,
270    cx: &AppContext,
271) -> BoxFuture<'static, Result<usize>> {
272    cx.background_executor()
273        .spawn(async move {
274            let messages = request
275                .messages
276                .into_iter()
277                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
278                    role: match message.role {
279                        Role::User => "user".into(),
280                        Role::Assistant => "assistant".into(),
281                        Role::System => "system".into(),
282                    },
283                    content: Some(message.content),
284                    name: None,
285                    function_call: None,
286                })
287                .collect::<Vec<_>>();
288
289            if let open_ai::Model::Custom { .. } = model {
290                tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
291            } else {
292                tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
293            }
294        })
295        .boxed()
296}
297
298struct AuthenticationPrompt {
299    api_key: View<Editor>,
300    state: gpui::Model<State>,
301}
302
303impl AuthenticationPrompt {
304    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
305        Self {
306            api_key: cx.new_view(|cx| {
307                let mut editor = Editor::single_line(cx);
308                editor.set_placeholder_text(
309                    "sk-000000000000000000000000000000000000000000000000",
310                    cx,
311                );
312                editor
313            }),
314            state,
315        }
316    }
317
318    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
319        let api_key = self.api_key.read(cx).text(cx);
320        if api_key.is_empty() {
321            return;
322        }
323
324        let settings = &AllLanguageModelSettings::get_global(cx).openai;
325        let write_credentials =
326            cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
327        let state = self.state.clone();
328        cx.spawn(|_, mut cx| async move {
329            write_credentials.await?;
330            state.update(&mut cx, |this, cx| {
331                this.api_key = Some(api_key);
332                cx.notify();
333            })
334        })
335        .detach_and_log_err(cx);
336    }
337
338    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
339        let settings = ThemeSettings::get_global(cx);
340        let text_style = TextStyle {
341            color: cx.theme().colors().text,
342            font_family: settings.ui_font.family.clone(),
343            font_features: settings.ui_font.features.clone(),
344            font_size: rems(0.875).into(),
345            font_weight: settings.ui_font.weight,
346            font_style: FontStyle::Normal,
347            line_height: relative(1.3),
348            background_color: None,
349            underline: None,
350            strikethrough: None,
351            white_space: WhiteSpace::Normal,
352        };
353        EditorElement::new(
354            &self.api_key,
355            EditorStyle {
356                background: cx.theme().colors().editor_background,
357                local_player: cx.theme().players().local(),
358                text: text_style,
359                ..Default::default()
360            },
361        )
362    }
363}
364
365impl Render for AuthenticationPrompt {
366    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
367        const INSTRUCTIONS: [&str; 6] = [
368            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
369            " - You can create an API key at: platform.openai.com/api-keys",
370            " - Make sure your OpenAI account has credits",
371            " - Having a subscription for another service like GitHub Copilot won't work.",
372            "",
373            "Paste your OpenAI API key below and hit enter to use the assistant:",
374        ];
375
376        v_flex()
377            .p_4()
378            .size_full()
379            .on_action(cx.listener(Self::save_api_key))
380            .children(
381                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
382            )
383            .child(
384                h_flex()
385                    .w_full()
386                    .my_2()
387                    .px_2()
388                    .py_1()
389                    .bg(cx.theme().colors().editor_background)
390                    .rounded_md()
391                    .child(self.render_api_key_editor(cx)),
392            )
393            .child(
394                Label::new(
395                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
396                )
397                .size(LabelSize::Small),
398            )
399            .child(
400                h_flex()
401                    .gap_2()
402                    .child(Label::new("Click on").size(LabelSize::Small))
403                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
404                    .child(
405                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
406                    ),
407            )
408            .into_any()
409    }
410}