open_ai.rs

  1use crate::{
  2    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  3};
  4use anyhow::{anyhow, Result};
  5use editor::{Editor, EditorElement, EditorStyle};
  6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
  8use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
  9use settings::Settings;
 10use std::{env, sync::Arc};
 11use theme::ThemeSettings;
 12use ui::prelude::*;
 13use util::{http::HttpClient, ResultExt};
 14
 15pub struct OpenAiCompletionProvider {
 16    api_key: Option<String>,
 17    api_url: String,
 18    default_model: OpenAiModel,
 19    http_client: Arc<dyn HttpClient>,
 20    settings_version: usize,
 21}
 22
 23impl OpenAiCompletionProvider {
 24    pub fn new(
 25        default_model: OpenAiModel,
 26        api_url: String,
 27        http_client: Arc<dyn HttpClient>,
 28        settings_version: usize,
 29    ) -> Self {
 30        Self {
 31            api_key: None,
 32            api_url,
 33            default_model,
 34            http_client,
 35            settings_version,
 36        }
 37    }
 38
 39    pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
 40        self.default_model = default_model;
 41        self.api_url = api_url;
 42        self.settings_version = settings_version;
 43    }
 44
 45    pub fn settings_version(&self) -> usize {
 46        self.settings_version
 47    }
 48
 49    pub fn is_authenticated(&self) -> bool {
 50        self.api_key.is_some()
 51    }
 52
 53    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 54        if self.is_authenticated() {
 55            Task::ready(Ok(()))
 56        } else {
 57            let api_url = self.api_url.clone();
 58            cx.spawn(|mut cx| async move {
 59                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
 60                    api_key
 61                } else {
 62                    let (_, api_key) = cx
 63                        .update(|cx| cx.read_credentials(&api_url))?
 64                        .await?
 65                        .ok_or_else(|| anyhow!("credentials not found"))?;
 66                    String::from_utf8(api_key)?
 67                };
 68                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 69                    if let CompletionProvider::OpenAi(provider) = provider {
 70                        provider.api_key = Some(api_key);
 71                    }
 72                })
 73            })
 74        }
 75    }
 76
 77    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 78        let delete_credentials = cx.delete_credentials(&self.api_url);
 79        cx.spawn(|mut cx| async move {
 80            delete_credentials.await.log_err();
 81            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 82                if let CompletionProvider::OpenAi(provider) = provider {
 83                    provider.api_key = None;
 84                }
 85            })
 86        })
 87    }
 88
 89    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 90        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
 91            .into()
 92    }
 93
 94    pub fn default_model(&self) -> OpenAiModel {
 95        self.default_model.clone()
 96    }
 97
 98    pub fn count_tokens(
 99        &self,
100        request: LanguageModelRequest,
101        cx: &AppContext,
102    ) -> BoxFuture<'static, Result<usize>> {
103        count_open_ai_tokens(request, cx.background_executor())
104    }
105
106    pub fn complete(
107        &self,
108        request: LanguageModelRequest,
109    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
110        let request = self.to_open_ai_request(request);
111
112        let http_client = self.http_client.clone();
113        let api_key = self.api_key.clone();
114        let api_url = self.api_url.clone();
115        async move {
116            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
117            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
118            let response = request.await?;
119            let stream = response
120                .filter_map(|response| async move {
121                    match response {
122                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
123                        Err(error) => Some(Err(error)),
124                    }
125                })
126                .boxed();
127            Ok(stream)
128        }
129        .boxed()
130    }
131
132    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
133        let model = match request.model {
134            LanguageModel::ZedDotDev(_) => self.default_model(),
135            LanguageModel::OpenAi(model) => model,
136        };
137
138        Request {
139            model,
140            messages: request
141                .messages
142                .into_iter()
143                .map(|msg| RequestMessage {
144                    role: msg.role.into(),
145                    content: msg.content,
146                })
147                .collect(),
148            stream: true,
149            stop: request.stop,
150            temperature: request.temperature,
151        }
152    }
153}
154
155pub fn count_open_ai_tokens(
156    request: LanguageModelRequest,
157    background_executor: &gpui::BackgroundExecutor,
158) -> BoxFuture<'static, Result<usize>> {
159    background_executor
160        .spawn(async move {
161            let messages = request
162                .messages
163                .into_iter()
164                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
165                    role: match message.role {
166                        Role::User => "user".into(),
167                        Role::Assistant => "assistant".into(),
168                        Role::System => "system".into(),
169                    },
170                    content: Some(message.content),
171                    name: None,
172                    function_call: None,
173                })
174                .collect::<Vec<_>>();
175
176            tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
177        })
178        .boxed()
179}
180
181impl From<Role> for open_ai::Role {
182    fn from(val: Role) -> Self {
183        match val {
184            Role::User => OpenAiRole::User,
185            Role::Assistant => OpenAiRole::Assistant,
186            Role::System => OpenAiRole::System,
187        }
188    }
189}
190
191struct AuthenticationPrompt {
192    api_key: View<Editor>,
193    api_url: String,
194}
195
196impl AuthenticationPrompt {
197    fn new(api_url: String, cx: &mut WindowContext) -> Self {
198        Self {
199            api_key: cx.new_view(|cx| {
200                let mut editor = Editor::single_line(cx);
201                editor.set_placeholder_text(
202                    "sk-000000000000000000000000000000000000000000000000",
203                    cx,
204                );
205                editor
206            }),
207            api_url,
208        }
209    }
210
211    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
212        let api_key = self.api_key.read(cx).text(cx);
213        if api_key.is_empty() {
214            return;
215        }
216
217        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
218        cx.spawn(|_, mut cx| async move {
219            write_credentials.await?;
220            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
221                if let CompletionProvider::OpenAi(provider) = provider {
222                    provider.api_key = Some(api_key);
223                }
224            })
225        })
226        .detach_and_log_err(cx);
227    }
228
229    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
230        let settings = ThemeSettings::get_global(cx);
231        let text_style = TextStyle {
232            color: cx.theme().colors().text,
233            font_family: settings.ui_font.family.clone(),
234            font_features: settings.ui_font.features,
235            font_size: rems(0.875).into(),
236            font_weight: FontWeight::NORMAL,
237            font_style: FontStyle::Normal,
238            line_height: relative(1.3),
239            background_color: None,
240            underline: None,
241            strikethrough: None,
242            white_space: WhiteSpace::Normal,
243        };
244        EditorElement::new(
245            &self.api_key,
246            EditorStyle {
247                background: cx.theme().colors().editor_background,
248                local_player: cx.theme().players().local(),
249                text: text_style,
250                ..Default::default()
251            },
252        )
253    }
254}
255
256impl Render for AuthenticationPrompt {
257    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
258        const INSTRUCTIONS: [&str; 6] = [
259            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
260            " - You can create an API key at: platform.openai.com/api-keys",
261            " - Make sure your OpenAI account has credits",
262            " - Having a subscription for another service like GitHub Copilot won't work.",
263            "",
264            "Paste your OpenAI API key below and hit enter to use the assistant:",
265        ];
266
267        v_flex()
268            .p_4()
269            .size_full()
270            .on_action(cx.listener(Self::save_api_key))
271            .children(
272                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
273            )
274            .child(
275                h_flex()
276                    .w_full()
277                    .my_2()
278                    .px_2()
279                    .py_1()
280                    .bg(cx.theme().colors().editor_background)
281                    .rounded_md()
282                    .child(self.render_api_key_editor(cx)),
283            )
284            .child(
285                Label::new(
286                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
287                )
288                .size(LabelSize::Small),
289            )
290            .child(
291                h_flex()
292                    .gap_2()
293                    .child(Label::new("Click on").size(LabelSize::Small))
294                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
295                    .child(
296                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
297                    ),
298            )
299            .into_any()
300    }
301}