open_ai.rs

  1use crate::assistant_settings::ZedDotDevModel;
  2use crate::{
  3    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  4};
  5use anyhow::{anyhow, Result};
  6use editor::{Editor, EditorElement, EditorStyle};
  7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  8use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
  9use http::HttpClient;
 10use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 11use settings::Settings;
 12use std::time::Duration;
 13use std::{env, sync::Arc};
 14use theme::ThemeSettings;
 15use ui::prelude::*;
 16use util::ResultExt;
 17
 18pub struct OpenAiCompletionProvider {
 19    api_key: Option<String>,
 20    api_url: String,
 21    default_model: OpenAiModel,
 22    http_client: Arc<dyn HttpClient>,
 23    low_speed_timeout: Option<Duration>,
 24    settings_version: usize,
 25}
 26
 27impl OpenAiCompletionProvider {
 28    pub fn new(
 29        default_model: OpenAiModel,
 30        api_url: String,
 31        http_client: Arc<dyn HttpClient>,
 32        low_speed_timeout: Option<Duration>,
 33        settings_version: usize,
 34    ) -> Self {
 35        Self {
 36            api_key: None,
 37            api_url,
 38            default_model,
 39            http_client,
 40            low_speed_timeout,
 41            settings_version,
 42        }
 43    }
 44
 45    pub fn update(
 46        &mut self,
 47        default_model: OpenAiModel,
 48        api_url: String,
 49        low_speed_timeout: Option<Duration>,
 50        settings_version: usize,
 51    ) {
 52        self.default_model = default_model;
 53        self.api_url = api_url;
 54        self.low_speed_timeout = low_speed_timeout;
 55        self.settings_version = settings_version;
 56    }
 57
 58    pub fn settings_version(&self) -> usize {
 59        self.settings_version
 60    }
 61
 62    pub fn is_authenticated(&self) -> bool {
 63        self.api_key.is_some()
 64    }
 65
 66    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 67        if self.is_authenticated() {
 68            Task::ready(Ok(()))
 69        } else {
 70            let api_url = self.api_url.clone();
 71            cx.spawn(|mut cx| async move {
 72                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
 73                    api_key
 74                } else {
 75                    let (_, api_key) = cx
 76                        .update(|cx| cx.read_credentials(&api_url))?
 77                        .await?
 78                        .ok_or_else(|| anyhow!("credentials not found"))?;
 79                    String::from_utf8(api_key)?
 80                };
 81                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 82                    if let CompletionProvider::OpenAi(provider) = provider {
 83                        provider.api_key = Some(api_key);
 84                    }
 85                })
 86            })
 87        }
 88    }
 89
 90    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 91        let delete_credentials = cx.delete_credentials(&self.api_url);
 92        cx.spawn(|mut cx| async move {
 93            delete_credentials.await.log_err();
 94            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 95                if let CompletionProvider::OpenAi(provider) = provider {
 96                    provider.api_key = None;
 97                }
 98            })
 99        })
100    }
101
102    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
103        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
104            .into()
105    }
106
107    pub fn default_model(&self) -> OpenAiModel {
108        self.default_model.clone()
109    }
110
111    pub fn count_tokens(
112        &self,
113        request: LanguageModelRequest,
114        cx: &AppContext,
115    ) -> BoxFuture<'static, Result<usize>> {
116        count_open_ai_tokens(request, cx.background_executor())
117    }
118
119    pub fn complete(
120        &self,
121        request: LanguageModelRequest,
122    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
123        let request = self.to_open_ai_request(request);
124
125        let http_client = self.http_client.clone();
126        let api_key = self.api_key.clone();
127        let api_url = self.api_url.clone();
128        let low_speed_timeout = self.low_speed_timeout;
129        async move {
130            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
131            let request = stream_completion(
132                http_client.as_ref(),
133                &api_url,
134                &api_key,
135                request,
136                low_speed_timeout,
137            );
138            let response = request.await?;
139            let stream = response
140                .filter_map(|response| async move {
141                    match response {
142                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
143                        Err(error) => Some(Err(error)),
144                    }
145                })
146                .boxed();
147            Ok(stream)
148        }
149        .boxed()
150    }
151
152    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
153        let model = match request.model {
154            LanguageModel::OpenAi(model) => model,
155            _ => self.default_model(),
156        };
157
158        Request {
159            model,
160            messages: request
161                .messages
162                .into_iter()
163                .map(|msg| match msg.role {
164                    Role::User => RequestMessage::User {
165                        content: msg.content,
166                    },
167                    Role::Assistant => RequestMessage::Assistant {
168                        content: Some(msg.content),
169                        tool_calls: Vec::new(),
170                    },
171                    Role::System => RequestMessage::System {
172                        content: msg.content,
173                    },
174                })
175                .collect(),
176            stream: true,
177            stop: request.stop,
178            temperature: request.temperature,
179            tools: Vec::new(),
180            tool_choice: None,
181        }
182    }
183}
184
185pub fn count_open_ai_tokens(
186    request: LanguageModelRequest,
187    background_executor: &gpui::BackgroundExecutor,
188) -> BoxFuture<'static, Result<usize>> {
189    background_executor
190        .spawn(async move {
191            let messages = request
192                .messages
193                .into_iter()
194                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
195                    role: match message.role {
196                        Role::User => "user".into(),
197                        Role::Assistant => "assistant".into(),
198                        Role::System => "system".into(),
199                    },
200                    content: Some(message.content),
201                    name: None,
202                    function_call: None,
203                })
204                .collect::<Vec<_>>();
205
206            match request.model {
207                LanguageModel::Anthropic(_)
208                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
209                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
210                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
211                    // Tiktoken doesn't yet support these models, so we manually use the
212                    // same tokenizer as GPT-4.
213                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
214                }
215                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
216            }
217        })
218        .boxed()
219}
220
221impl From<Role> for open_ai::Role {
222    fn from(val: Role) -> Self {
223        match val {
224            Role::User => OpenAiRole::User,
225            Role::Assistant => OpenAiRole::Assistant,
226            Role::System => OpenAiRole::System,
227        }
228    }
229}
230
231struct AuthenticationPrompt {
232    api_key: View<Editor>,
233    api_url: String,
234}
235
236impl AuthenticationPrompt {
237    fn new(api_url: String, cx: &mut WindowContext) -> Self {
238        Self {
239            api_key: cx.new_view(|cx| {
240                let mut editor = Editor::single_line(cx);
241                editor.set_placeholder_text(
242                    "sk-000000000000000000000000000000000000000000000000",
243                    cx,
244                );
245                editor
246            }),
247            api_url,
248        }
249    }
250
251    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
252        let api_key = self.api_key.read(cx).text(cx);
253        if api_key.is_empty() {
254            return;
255        }
256
257        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
258        cx.spawn(|_, mut cx| async move {
259            write_credentials.await?;
260            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
261                if let CompletionProvider::OpenAi(provider) = provider {
262                    provider.api_key = Some(api_key);
263                }
264            })
265        })
266        .detach_and_log_err(cx);
267    }
268
269    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
270        let settings = ThemeSettings::get_global(cx);
271        let text_style = TextStyle {
272            color: cx.theme().colors().text,
273            font_family: settings.ui_font.family.clone(),
274            font_features: settings.ui_font.features.clone(),
275            font_size: rems(0.875).into(),
276            font_weight: settings.ui_font.weight,
277            font_style: FontStyle::Normal,
278            line_height: relative(1.3),
279            background_color: None,
280            underline: None,
281            strikethrough: None,
282            white_space: WhiteSpace::Normal,
283        };
284        EditorElement::new(
285            &self.api_key,
286            EditorStyle {
287                background: cx.theme().colors().editor_background,
288                local_player: cx.theme().players().local(),
289                text: text_style,
290                ..Default::default()
291            },
292        )
293    }
294}
295
296impl Render for AuthenticationPrompt {
297    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
298        const INSTRUCTIONS: [&str; 6] = [
299            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
300            " - You can create an API key at: platform.openai.com/api-keys",
301            " - Make sure your OpenAI account has credits",
302            " - Having a subscription for another service like GitHub Copilot won't work.",
303            "",
304            "Paste your OpenAI API key below and hit enter to use the assistant:",
305        ];
306
307        v_flex()
308            .p_4()
309            .size_full()
310            .on_action(cx.listener(Self::save_api_key))
311            .children(
312                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
313            )
314            .child(
315                h_flex()
316                    .w_full()
317                    .my_2()
318                    .px_2()
319                    .py_1()
320                    .bg(cx.theme().colors().editor_background)
321                    .rounded_md()
322                    .child(self.render_api_key_editor(cx)),
323            )
324            .child(
325                Label::new(
326                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
327                )
328                .size(LabelSize::Small),
329            )
330            .child(
331                h_flex()
332                    .gap_2()
333                    .child(Label::new("Click on").size(LabelSize::Small))
334                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
335                    .child(
336                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
337                    ),
338            )
339            .into_any()
340    }
341}