open_ai.rs

use crate::assistant_settings::ZedDotDevModel;
use crate::{
    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;

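/// Completion provider backed by the OpenAI chat completions API. It holds the
/// API key, endpoint URL, default model, and HTTP client used to stream
/// completions, and backs the `CompletionProvider::OpenAi` variant of the
/// global completion provider.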
pub struct OpenAiCompletionProvider {
    api_key: Option<String>,
    api_url: String,
    default_model: OpenAiModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
}

impl OpenAiCompletionProvider {
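    /// Constructs a provider with no API key; call `authenticate` before
    /// requesting completions. A minimal usage sketch, assuming an
    /// `http_client: Arc<dyn HttpClient>` is already in hand (the URL, timeout,
    /// and settings version below are illustrative):
    ///
    /// ```ignore
    /// // Illustrative values; adjust to your settings.
    /// let provider = OpenAiCompletionProvider::new(
    ///     OpenAiModel::FourOmni,
    ///     "https://api.openai.com/v1".to_string(),
    ///     http_client,
    ///     Some(Duration::from_secs(30)),
    ///     0,
    /// );
    /// ```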
    pub fn new(
        default_model: OpenAiModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_key: None,
            api_url,
            default_model,
            http_client,
            low_speed_timeout,
            settings_version,
        }
    }

    pub fn update(
        &mut self,
        default_model: OpenAiModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.default_model = default_model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    pub fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

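    /// Resolves an API key, preferring the `OPENAI_API_KEY` environment
    /// variable and falling back to the credentials stored for the API URL,
    /// then records it on the global `CompletionProvider`.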
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    if let CompletionProvider::OpenAi(provider) = provider {
                        provider.api_key = Some(api_key);
                    }
                })
            })
        }
    }

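    /// Deletes any credentials stored for the API URL (logging failures rather
    /// than propagating them) and clears the in-memory API key.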
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = None;
                }
            })
        })
    }

    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    pub fn default_model(&self) -> OpenAiModel {
        self.default_model.clone()
    }

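    /// Estimates the token count of `request` on the background executor.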
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

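    /// Streams a completion for `request`, yielding each content delta from
    /// the OpenAI chat completions response as it arrives.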
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

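    /// Converts a `LanguageModelRequest` into an OpenAI streaming chat
    /// request, falling back to the provider's default model when the request
    /// does not name an OpenAI model.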
    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
        let model = match request.model {
            LanguageModel::OpenAi(model) => model,
            _ => self.default_model(),
        };

        Request {
            model,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => RequestMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => RequestMessage::Assistant {
                        content: Some(msg.content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => RequestMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            stream: true,
            stop: request.stop,
            temperature: request.temperature,
            tools: Vec::new(),
            tool_choice: None,
        }
    }
}

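/// Counts the tokens in `request` using tiktoken-rs, spawning the work on the
/// background executor. Models without a dedicated tokenizer are approximated
/// with the GPT-4 tokenizer.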
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
    background_executor
        .spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.content),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            match request.model {
                LanguageModel::OpenAi(OpenAiModel::FourOmni)
                | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Omni)
                | LanguageModel::Anthropic(_)
                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Opus)
                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Sonnet)
                | LanguageModel::ZedDotDev(ZedDotDevModel::Claude3Haiku) => {
                    // Tiktoken doesn't yet support these models, so we manually use the
                    // same tokenizer as GPT-4.
                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
                }
                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
            }
        })
        .boxed()
}

impl From<Role> for open_ai::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OpenAiRole::User,
            Role::Assistant => OpenAiRole::Assistant,
            Role::System => OpenAiRole::System,
        }
    }
}

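/// View shown while no API key is configured: a single-line editor where the
/// user can paste a key, which is then written to the system credential store.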
struct AuthenticationPrompt {
    api_key: View<Editor>,
    api_url: String,
}

impl AuthenticationPrompt {
    fn new(api_url: String, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(
                    "sk-000000000000000000000000000000000000000000000000",
                    cx,
                );
                editor
            }),
            api_url,
        }
    }

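    /// Persists the entered key to the credential store under the API URL and
    /// updates the global `CompletionProvider` so it takes effect immediately.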
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
        cx.spawn(|_, mut cx| async move {
            write_credentials.await?;
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = Some(api_key);
                }
            })
        })
        .detach_and_log_err(cx);
    }

    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_size: rems(0.875).into(),
            font_weight: FontWeight::NORMAL,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}

impl Render for AuthenticationPrompt {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 6] = [
            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
            " - You can create an API key at: platform.openai.com/api-keys",
            " - Make sure your OpenAI account has credits",
            " - Having a subscription for another service like GitHub Copilot won't work.",
            "",
            "Paste your OpenAI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}