open_ai.rs

  1use crate::{
  2    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  3};
  4use anyhow::{anyhow, Result};
  5use editor::{Editor, EditorElement, EditorStyle};
  6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
  8use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
  9use settings::Settings;
 10use std::{env, sync::Arc};
 11use theme::ThemeSettings;
 12use ui::prelude::*;
 13use util::{http::HttpClient, ResultExt};
 14
 15pub struct OpenAiCompletionProvider {
 16    api_key: Option<String>,
 17    api_url: String,
 18    default_model: OpenAiModel,
 19    http_client: Arc<dyn HttpClient>,
 20    settings_version: usize,
 21}
 22
 23impl OpenAiCompletionProvider {
 24    pub fn new(
 25        default_model: OpenAiModel,
 26        api_url: String,
 27        http_client: Arc<dyn HttpClient>,
 28        settings_version: usize,
 29    ) -> Self {
 30        Self {
 31            api_key: None,
 32            api_url,
 33            default_model,
 34            http_client,
 35            settings_version,
 36        }
 37    }
 38
 39    pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
 40        self.default_model = default_model;
 41        self.api_url = api_url;
 42        self.settings_version = settings_version;
 43    }
 44
 45    pub fn settings_version(&self) -> usize {
 46        self.settings_version
 47    }
 48
 49    pub fn is_authenticated(&self) -> bool {
 50        self.api_key.is_some()
 51    }
 52
 53    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 54        if self.is_authenticated() {
 55            Task::ready(Ok(()))
 56        } else {
 57            let api_url = self.api_url.clone();
 58            cx.spawn(|mut cx| async move {
 59                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
 60                    api_key
 61                } else {
 62                    let (_, api_key) = cx
 63                        .update(|cx| cx.read_credentials(&api_url))?
 64                        .await?
 65                        .ok_or_else(|| anyhow!("credentials not found"))?;
 66                    String::from_utf8(api_key)?
 67                };
 68                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 69                    if let CompletionProvider::OpenAi(provider) = provider {
 70                        provider.api_key = Some(api_key);
 71                    }
 72                })
 73            })
 74        }
 75    }
 76
 77    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 78        let delete_credentials = cx.delete_credentials(&self.api_url);
 79        cx.spawn(|mut cx| async move {
 80            delete_credentials.await.log_err();
 81            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 82                if let CompletionProvider::OpenAi(provider) = provider {
 83                    provider.api_key = None;
 84                }
 85            })
 86        })
 87    }
 88
 89    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 90        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
 91            .into()
 92    }
 93
 94    pub fn default_model(&self) -> OpenAiModel {
 95        self.default_model.clone()
 96    }
 97
 98    pub fn count_tokens(
 99        &self,
100        request: LanguageModelRequest,
101        cx: &AppContext,
102    ) -> BoxFuture<'static, Result<usize>> {
103        count_open_ai_tokens(request, cx.background_executor())
104    }
105
106    pub fn complete(
107        &self,
108        request: LanguageModelRequest,
109    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
110        let request = self.to_open_ai_request(request);
111
112        let http_client = self.http_client.clone();
113        let api_key = self.api_key.clone();
114        let api_url = self.api_url.clone();
115        async move {
116            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
117            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
118            let response = request.await?;
119            let stream = response
120                .filter_map(|response| async move {
121                    match response {
122                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
123                        Err(error) => Some(Err(error)),
124                    }
125                })
126                .boxed();
127            Ok(stream)
128        }
129        .boxed()
130    }
131
132    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
133        let model = match request.model {
134            LanguageModel::ZedDotDev(_) => self.default_model(),
135            LanguageModel::OpenAi(model) => model,
136        };
137
138        Request {
139            model,
140            messages: request
141                .messages
142                .into_iter()
143                .map(|msg| match msg.role {
144                    Role::User => RequestMessage::User {
145                        content: msg.content,
146                    },
147                    Role::Assistant => RequestMessage::Assistant {
148                        content: Some(msg.content),
149                        tool_calls: Vec::new(),
150                    },
151                    Role::System => RequestMessage::System {
152                        content: msg.content,
153                    },
154                })
155                .collect(),
156            stream: true,
157            stop: request.stop,
158            temperature: request.temperature,
159            tools: Vec::new(),
160            tool_choice: None,
161        }
162    }
163}
164
165pub fn count_open_ai_tokens(
166    request: LanguageModelRequest,
167    background_executor: &gpui::BackgroundExecutor,
168) -> BoxFuture<'static, Result<usize>> {
169    background_executor
170        .spawn(async move {
171            let messages = request
172                .messages
173                .into_iter()
174                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
175                    role: match message.role {
176                        Role::User => "user".into(),
177                        Role::Assistant => "assistant".into(),
178                        Role::System => "system".into(),
179                    },
180                    content: Some(message.content),
181                    name: None,
182                    function_call: None,
183                })
184                .collect::<Vec<_>>();
185
186            tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
187        })
188        .boxed()
189}
190
191impl From<Role> for open_ai::Role {
192    fn from(val: Role) -> Self {
193        match val {
194            Role::User => OpenAiRole::User,
195            Role::Assistant => OpenAiRole::Assistant,
196            Role::System => OpenAiRole::System,
197        }
198    }
199}
200
201struct AuthenticationPrompt {
202    api_key: View<Editor>,
203    api_url: String,
204}
205
206impl AuthenticationPrompt {
207    fn new(api_url: String, cx: &mut WindowContext) -> Self {
208        Self {
209            api_key: cx.new_view(|cx| {
210                let mut editor = Editor::single_line(cx);
211                editor.set_placeholder_text(
212                    "sk-000000000000000000000000000000000000000000000000",
213                    cx,
214                );
215                editor
216            }),
217            api_url,
218        }
219    }
220
221    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
222        let api_key = self.api_key.read(cx).text(cx);
223        if api_key.is_empty() {
224            return;
225        }
226
227        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
228        cx.spawn(|_, mut cx| async move {
229            write_credentials.await?;
230            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
231                if let CompletionProvider::OpenAi(provider) = provider {
232                    provider.api_key = Some(api_key);
233                }
234            })
235        })
236        .detach_and_log_err(cx);
237    }
238
239    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
240        let settings = ThemeSettings::get_global(cx);
241        let text_style = TextStyle {
242            color: cx.theme().colors().text,
243            font_family: settings.ui_font.family.clone(),
244            font_features: settings.ui_font.features,
245            font_size: rems(0.875).into(),
246            font_weight: FontWeight::NORMAL,
247            font_style: FontStyle::Normal,
248            line_height: relative(1.3),
249            background_color: None,
250            underline: None,
251            strikethrough: None,
252            white_space: WhiteSpace::Normal,
253        };
254        EditorElement::new(
255            &self.api_key,
256            EditorStyle {
257                background: cx.theme().colors().editor_background,
258                local_player: cx.theme().players().local(),
259                text: text_style,
260                ..Default::default()
261            },
262        )
263    }
264}
265
266impl Render for AuthenticationPrompt {
267    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
268        const INSTRUCTIONS: [&str; 6] = [
269            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
270            " - You can create an API key at: platform.openai.com/api-keys",
271            " - Make sure your OpenAI account has credits",
272            " - Having a subscription for another service like GitHub Copilot won't work.",
273            "",
274            "Paste your OpenAI API key below and hit enter to use the assistant:",
275        ];
276
277        v_flex()
278            .p_4()
279            .size_full()
280            .on_action(cx.listener(Self::save_api_key))
281            .children(
282                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
283            )
284            .child(
285                h_flex()
286                    .w_full()
287                    .my_2()
288                    .px_2()
289                    .py_1()
290                    .bg(cx.theme().colors().editor_background)
291                    .rounded_md()
292                    .child(self.render_api_key_editor(cx)),
293            )
294            .child(
295                Label::new(
296                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
297                )
298                .size(LabelSize::Small),
299            )
300            .child(
301                h_flex()
302                    .gap_2()
303                    .child(Label::new("Click on").size(LabelSize::Small))
304                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
305                    .child(
306                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
307                    ),
308            )
309            .into_any()
310    }
311}