open_ai.rs

  1use anyhow::{anyhow, Result};
  2use collections::HashMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{future::BoxFuture, FutureExt, StreamExt};
  5use gpui::{
  6    AnyView, AppContext, AsyncAppContext, FontStyle, Subscription, Task, TextStyle, View,
  7    WhiteSpace,
  8};
  9use http::HttpClient;
 10use open_ai::{stream_completion, Request, RequestMessage};
 11use settings::{Settings, SettingsStore};
 12use std::{sync::Arc, time::Duration};
 13use strum::IntoEnumIterator;
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18use crate::{
 19    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
 20    LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
 21    LanguageModelRequest, Role,
 22};
 23
 24const PROVIDER_NAME: &str = "openai";
 25
 26#[derive(Default, Clone, Debug, PartialEq)]
 27pub struct OpenAiSettings {
 28    pub api_url: String,
 29    pub low_speed_timeout: Option<Duration>,
 30    pub available_models: Vec<open_ai::Model>,
 31}
 32
 33pub struct OpenAiLanguageModelProvider {
 34    http_client: Arc<dyn HttpClient>,
 35    state: gpui::Model<State>,
 36}
 37
 38struct State {
 39    api_key: Option<String>,
 40    settings: OpenAiSettings,
 41    _subscription: Subscription,
 42}
 43
 44impl OpenAiLanguageModelProvider {
 45    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
 46        let state = cx.new_model(|cx| State {
 47            api_key: None,
 48            settings: OpenAiSettings::default(),
 49            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 50                this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
 51                cx.notify();
 52            }),
 53        });
 54
 55        Self { http_client, state }
 56    }
 57}
 58
 59impl LanguageModelProviderState for OpenAiLanguageModelProvider {
 60    fn subscribe<T: 'static>(&self, cx: &mut gpui::ModelContext<T>) -> Option<gpui::Subscription> {
 61        Some(cx.observe(&self.state, |_, _, cx| {
 62            cx.notify();
 63        }))
 64    }
 65}
 66
 67impl LanguageModelProvider for OpenAiLanguageModelProvider {
 68    fn name(&self) -> LanguageModelProviderName {
 69        LanguageModelProviderName(PROVIDER_NAME.into())
 70    }
 71
 72    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
 73        let mut models = HashMap::default();
 74
 75        // Add base models from open_ai::Model::iter()
 76        for model in open_ai::Model::iter() {
 77            if !matches!(model, open_ai::Model::Custom { .. }) {
 78                models.insert(model.id().to_string(), model);
 79            }
 80        }
 81
 82        // Override with available models from settings
 83        for model in &self.state.read(cx).settings.available_models {
 84            models.insert(model.id().to_string(), model.clone());
 85        }
 86
 87        models
 88            .into_values()
 89            .map(|model| {
 90                Arc::new(OpenAiLanguageModel {
 91                    id: LanguageModelId::from(model.id().to_string()),
 92                    model,
 93                    state: self.state.clone(),
 94                    http_client: self.http_client.clone(),
 95                }) as Arc<dyn LanguageModel>
 96            })
 97            .collect()
 98    }
 99
100    fn is_authenticated(&self, cx: &AppContext) -> bool {
101        self.state.read(cx).api_key.is_some()
102    }
103
104    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
105        if self.is_authenticated(cx) {
106            Task::ready(Ok(()))
107        } else {
108            let api_url = self.state.read(cx).settings.api_url.clone();
109            let state = self.state.clone();
110            cx.spawn(|mut cx| async move {
111                let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
112                    api_key
113                } else {
114                    let (_, api_key) = cx
115                        .update(|cx| cx.read_credentials(&api_url))?
116                        .await?
117                        .ok_or_else(|| anyhow!("credentials not found"))?;
118                    String::from_utf8(api_key)?
119                };
120                state.update(&mut cx, |this, cx| {
121                    this.api_key = Some(api_key);
122                    cx.notify();
123                })
124            })
125        }
126    }
127
128    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
129        cx.new_view(|cx| AuthenticationPrompt::new(self.state.clone(), cx))
130            .into()
131    }
132
133    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
134        let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
135        let state = self.state.clone();
136        cx.spawn(|mut cx| async move {
137            delete_credentials.await.log_err();
138            state.update(&mut cx, |this, cx| {
139                this.api_key = None;
140                cx.notify();
141            })
142        })
143    }
144}
145
146pub struct OpenAiLanguageModel {
147    id: LanguageModelId,
148    model: open_ai::Model,
149    state: gpui::Model<State>,
150    http_client: Arc<dyn HttpClient>,
151}
152
153impl OpenAiLanguageModel {
154    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
155        Request {
156            model: self.model.clone(),
157            messages: request
158                .messages
159                .into_iter()
160                .map(|msg| match msg.role {
161                    Role::User => RequestMessage::User {
162                        content: msg.content,
163                    },
164                    Role::Assistant => RequestMessage::Assistant {
165                        content: Some(msg.content),
166                        tool_calls: Vec::new(),
167                    },
168                    Role::System => RequestMessage::System {
169                        content: msg.content,
170                    },
171                })
172                .collect(),
173            stream: true,
174            stop: request.stop,
175            temperature: request.temperature,
176            tools: Vec::new(),
177            tool_choice: None,
178        }
179    }
180}
181
182impl LanguageModel for OpenAiLanguageModel {
183    fn id(&self) -> LanguageModelId {
184        self.id.clone()
185    }
186
187    fn name(&self) -> LanguageModelName {
188        LanguageModelName::from(self.model.display_name().to_string())
189    }
190
191    fn provider_name(&self) -> LanguageModelProviderName {
192        LanguageModelProviderName(PROVIDER_NAME.into())
193    }
194
195    fn telemetry_id(&self) -> String {
196        format!("openai/{}", self.model.id())
197    }
198
199    fn max_token_count(&self) -> usize {
200        self.model.max_token_count()
201    }
202
203    fn count_tokens(
204        &self,
205        request: LanguageModelRequest,
206        cx: &AppContext,
207    ) -> BoxFuture<'static, Result<usize>> {
208        count_open_ai_tokens(request, self.model.clone(), cx)
209    }
210
211    fn stream_completion(
212        &self,
213        request: LanguageModelRequest,
214        cx: &AsyncAppContext,
215    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
216        let request = self.to_open_ai_request(request);
217
218        let http_client = self.http_client.clone();
219        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
220            (
221                state.api_key.clone(),
222                state.settings.api_url.clone(),
223                state.settings.low_speed_timeout,
224            )
225        }) else {
226            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
227        };
228
229        async move {
230            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
231            let request = stream_completion(
232                http_client.as_ref(),
233                &api_url,
234                &api_key,
235                request,
236                low_speed_timeout,
237            );
238            let response = request.await?;
239            let stream = response
240                .filter_map(|response| async move {
241                    match response {
242                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
243                        Err(error) => Some(Err(error)),
244                    }
245                })
246                .boxed();
247            Ok(stream)
248        }
249        .boxed()
250    }
251}
252
253pub fn count_open_ai_tokens(
254    request: LanguageModelRequest,
255    model: open_ai::Model,
256    cx: &AppContext,
257) -> BoxFuture<'static, Result<usize>> {
258    cx.background_executor()
259        .spawn(async move {
260            let messages = request
261                .messages
262                .into_iter()
263                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
264                    role: match message.role {
265                        Role::User => "user".into(),
266                        Role::Assistant => "assistant".into(),
267                        Role::System => "system".into(),
268                    },
269                    content: Some(message.content),
270                    name: None,
271                    function_call: None,
272                })
273                .collect::<Vec<_>>();
274
275            if let open_ai::Model::Custom { .. } = model {
276                tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
277            } else {
278                tiktoken_rs::num_tokens_from_messages(model.id(), &messages)
279            }
280        })
281        .boxed()
282}
283
284struct AuthenticationPrompt {
285    api_key: View<Editor>,
286    state: gpui::Model<State>,
287}
288
289impl AuthenticationPrompt {
290    fn new(state: gpui::Model<State>, cx: &mut WindowContext) -> Self {
291        Self {
292            api_key: cx.new_view(|cx| {
293                let mut editor = Editor::single_line(cx);
294                editor.set_placeholder_text(
295                    "sk-000000000000000000000000000000000000000000000000",
296                    cx,
297                );
298                editor
299            }),
300            state,
301        }
302    }
303
304    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
305        let api_key = self.api_key.read(cx).text(cx);
306        if api_key.is_empty() {
307            return;
308        }
309
310        let write_credentials = cx.write_credentials(
311            &self.state.read(cx).settings.api_url,
312            "Bearer",
313            api_key.as_bytes(),
314        );
315        let state = self.state.clone();
316        cx.spawn(|_, mut cx| async move {
317            write_credentials.await?;
318            state.update(&mut cx, |this, cx| {
319                this.api_key = Some(api_key);
320                cx.notify();
321            })
322        })
323        .detach_and_log_err(cx);
324    }
325
326    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
327        let settings = ThemeSettings::get_global(cx);
328        let text_style = TextStyle {
329            color: cx.theme().colors().text,
330            font_family: settings.ui_font.family.clone(),
331            font_features: settings.ui_font.features.clone(),
332            font_size: rems(0.875).into(),
333            font_weight: settings.ui_font.weight,
334            font_style: FontStyle::Normal,
335            line_height: relative(1.3),
336            background_color: None,
337            underline: None,
338            strikethrough: None,
339            white_space: WhiteSpace::Normal,
340        };
341        EditorElement::new(
342            &self.api_key,
343            EditorStyle {
344                background: cx.theme().colors().editor_background,
345                local_player: cx.theme().players().local(),
346                text: text_style,
347                ..Default::default()
348            },
349        )
350    }
351}
352
353impl Render for AuthenticationPrompt {
354    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
355        const INSTRUCTIONS: [&str; 6] = [
356            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
357            " - You can create an API key at: platform.openai.com/api-keys",
358            " - Make sure your OpenAI account has credits",
359            " - Having a subscription for another service like GitHub Copilot won't work.",
360            "",
361            "Paste your OpenAI API key below and hit enter to use the assistant:",
362        ];
363
364        v_flex()
365            .p_4()
366            .size_full()
367            .on_action(cx.listener(Self::save_api_key))
368            .children(
369                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
370            )
371            .child(
372                h_flex()
373                    .w_full()
374                    .my_2()
375                    .px_2()
376                    .py_1()
377                    .bg(cx.theme().colors().editor_background)
378                    .rounded_md()
379                    .child(self.render_api_key_editor(cx)),
380            )
381            .child(
382                Label::new(
383                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
384                )
385                .size(LabelSize::Small),
386            )
387            .child(
388                h_flex()
389                    .gap_2()
390                    .child(Label::new("Click on").size(LabelSize::Small))
391                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
392                    .child(
393                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
394                    ),
395            )
396            .into_any()
397    }
398}