// open_ai.rs

  1use crate::assistant_settings::CloudModel;
  2use crate::assistant_settings::{AssistantProvider, AssistantSettings};
  3use crate::LanguageModelCompletionProvider;
  4use crate::{
  5    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  6};
  7use anyhow::{anyhow, Result};
  8use editor::{Editor, EditorElement, EditorStyle};
  9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 11use http::HttpClient;
 12use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 13use settings::Settings;
 14use std::time::Duration;
 15use std::{env, sync::Arc};
 16use strum::IntoEnumIterator;
 17use theme::ThemeSettings;
 18use ui::prelude::*;
 19use util::ResultExt;
 20
/// Completion provider backed by an OpenAI-compatible chat-completions API.
pub struct OpenAiCompletionProvider {
    /// API key used to authenticate requests; populated by `authenticate`
    /// (OPENAI_API_KEY env var or stored credentials) or the auth prompt.
    api_key: Option<String>,
    /// Base URL of the OpenAI-compatible endpoint; also the credential-store key.
    api_url: String,
    /// Model used when a request does not specify an OpenAI model itself.
    model: OpenAiModel,
    /// HTTP client handed to `stream_completion`.
    http_client: Arc<dyn HttpClient>,
    /// Optional low-speed timeout forwarded to `stream_completion`.
    low_speed_timeout: Option<Duration>,
    /// Version of the settings this provider was last configured from.
    settings_version: usize,
}
 29
 30impl OpenAiCompletionProvider {
 31    pub fn new(
 32        model: OpenAiModel,
 33        api_url: String,
 34        http_client: Arc<dyn HttpClient>,
 35        low_speed_timeout: Option<Duration>,
 36        settings_version: usize,
 37    ) -> Self {
 38        Self {
 39            api_key: None,
 40            api_url,
 41            model,
 42            http_client,
 43            low_speed_timeout,
 44            settings_version,
 45        }
 46    }
 47
 48    pub fn update(
 49        &mut self,
 50        model: OpenAiModel,
 51        api_url: String,
 52        low_speed_timeout: Option<Duration>,
 53        settings_version: usize,
 54    ) {
 55        self.model = model;
 56        self.api_url = api_url;
 57        self.low_speed_timeout = low_speed_timeout;
 58        self.settings_version = settings_version;
 59    }
 60
 61    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
 62        let model = match request.model {
 63            LanguageModel::OpenAi(model) => model,
 64            _ => self.model.clone(),
 65        };
 66
 67        Request {
 68            model,
 69            messages: request
 70                .messages
 71                .into_iter()
 72                .map(|msg| match msg.role {
 73                    Role::User => RequestMessage::User {
 74                        content: msg.content,
 75                    },
 76                    Role::Assistant => RequestMessage::Assistant {
 77                        content: Some(msg.content),
 78                        tool_calls: Vec::new(),
 79                    },
 80                    Role::System => RequestMessage::System {
 81                        content: msg.content,
 82                    },
 83                })
 84                .collect(),
 85            stream: true,
 86            stop: request.stop,
 87            temperature: request.temperature,
 88            tools: Vec::new(),
 89            tool_choice: None,
 90        }
 91    }
 92}
 93
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
    /// Returns the models a user can pick from. Models listed in the
    /// assistant settings take precedence; otherwise a configured custom
    /// model is offered by itself, or else every built-in (non-custom)
    /// OpenAI model.
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
        // Settings-declared models, when non-empty, fully replace the defaults.
        if let AssistantProvider::OpenAi {
            available_models, ..
        } = &AssistantSettings::get_global(cx).provider
        {
            if !available_models.is_empty() {
                return available_models
                    .iter()
                    .cloned()
                    .map(LanguageModel::OpenAi)
                    .collect();
            }
        }
        let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
            // A custom model can't be enumerated; offer only the configured one.
            vec![self.model.clone()]
        } else {
            OpenAiModel::iter()
                .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                .collect()
        };
        available_models
            .into_iter()
            .map(LanguageModel::OpenAi)
            .collect()
    }

    /// Version of the settings this provider was last configured with.
    fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// The provider counts as authenticated once it holds an API key.
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    /// Obtains an API key and stores it on the globally registered provider.
    /// The OPENAI_API_KEY environment variable wins; otherwise the key is
    /// read from the system credential store, keyed by the API URL.
    /// Resolves immediately when already authenticated.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    // Credentials come back as a (username, secret) pair;
                    // only the secret half holds the key bytes.
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                // Write the key onto the provider registered in the global
                // CompletionProvider (not `self`, which may be stale by now).
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    provider.update_current_as::<_, Self>(|provider| {
                        provider.api_key = Some(api_key);
                    });
                })
            })
        }
    }

    /// Deletes the stored credentials for this API URL (best effort — a
    /// failure is only logged) and clears the key on the global provider.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, Self>(|provider| {
                    provider.api_key = None;
                });
            })
        })
    }

    /// View shown when the provider has no API key yet.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    /// The currently configured default model.
    fn model(&self) -> LanguageModel {
        LanguageModel::OpenAi(self.model.clone())
    }

    /// Token-count estimate for `request`, computed on the background executor.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

    /// Streams completion text chunks for `request` from the OpenAI API.
    /// Fails up front when no API key has been set.
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        // Clone everything the async block needs so the future is 'static.
        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // Chunks with no choice or no content delta are
                        // skipped (`?` yields None, which filter_map drops).
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
219
220pub fn count_open_ai_tokens(
221    request: LanguageModelRequest,
222    background_executor: &gpui::BackgroundExecutor,
223) -> BoxFuture<'static, Result<usize>> {
224    background_executor
225        .spawn(async move {
226            let messages = request
227                .messages
228                .into_iter()
229                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
230                    role: match message.role {
231                        Role::User => "user".into(),
232                        Role::Assistant => "assistant".into(),
233                        Role::System => "system".into(),
234                    },
235                    content: Some(message.content),
236                    name: None,
237                    function_call: None,
238                })
239                .collect::<Vec<_>>();
240
241            match request.model {
242                LanguageModel::Anthropic(_)
243                | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
244                | LanguageModel::Cloud(CloudModel::Claude3Opus)
245                | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
246                | LanguageModel::Cloud(CloudModel::Claude3Haiku)
247                | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
248                    // Tiktoken doesn't yet support these models, so we manually use the
249                    // same tokenizer as GPT-4.
250                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
251                }
252                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
253            }
254        })
255        .boxed()
256}
257
258impl From<Role> for open_ai::Role {
259    fn from(val: Role) -> Self {
260        match val {
261            Role::User => OpenAiRole::User,
262            Role::Assistant => OpenAiRole::Assistant,
263            Role::System => OpenAiRole::System,
264        }
265    }
266}
267
/// View that prompts the user to paste an OpenAI API key.
struct AuthenticationPrompt {
    /// Single-line editor the user types the API key into.
    api_key: View<Editor>,
    /// URL the key is stored against in the system credential store.
    api_url: String,
}
272
273impl AuthenticationPrompt {
274    fn new(api_url: String, cx: &mut WindowContext) -> Self {
275        Self {
276            api_key: cx.new_view(|cx| {
277                let mut editor = Editor::single_line(cx);
278                editor.set_placeholder_text(
279                    "sk-000000000000000000000000000000000000000000000000",
280                    cx,
281                );
282                editor
283            }),
284            api_url,
285        }
286    }
287
288    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
289        let api_key = self.api_key.read(cx).text(cx);
290        if api_key.is_empty() {
291            return;
292        }
293
294        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
295        cx.spawn(|_, mut cx| async move {
296            write_credentials.await?;
297            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
298                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
299                    provider.api_key = Some(api_key);
300                });
301            })
302        })
303        .detach_and_log_err(cx);
304    }
305
306    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
307        let settings = ThemeSettings::get_global(cx);
308        let text_style = TextStyle {
309            color: cx.theme().colors().text,
310            font_family: settings.ui_font.family.clone(),
311            font_features: settings.ui_font.features.clone(),
312            font_size: rems(0.875).into(),
313            font_weight: settings.ui_font.weight,
314            font_style: FontStyle::Normal,
315            line_height: relative(1.3),
316            background_color: None,
317            underline: None,
318            strikethrough: None,
319            white_space: WhiteSpace::Normal,
320        };
321        EditorElement::new(
322            &self.api_key,
323            EditorStyle {
324                background: cx.theme().colors().editor_background,
325                local_player: cx.theme().players().local(),
326                text: text_style,
327                ..Default::default()
328            },
329        )
330    }
331}
332
333impl Render for AuthenticationPrompt {
334    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
335        const INSTRUCTIONS: [&str; 6] = [
336            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
337            " - You can create an API key at: platform.openai.com/api-keys",
338            " - Make sure your OpenAI account has credits",
339            " - Having a subscription for another service like GitHub Copilot won't work.",
340            "",
341            "Paste your OpenAI API key below and hit enter to use the assistant:",
342        ];
343
344        v_flex()
345            .p_4()
346            .size_full()
347            .on_action(cx.listener(Self::save_api_key))
348            .children(
349                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
350            )
351            .child(
352                h_flex()
353                    .w_full()
354                    .my_2()
355                    .px_2()
356                    .py_1()
357                    .bg(cx.theme().colors().editor_background)
358                    .rounded_md()
359                    .child(self.render_api_key_editor(cx)),
360            )
361            .child(
362                Label::new(
363                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
364                )
365                .size(LabelSize::Small),
366            )
367            .child(
368                h_flex()
369                    .gap_2()
370                    .child(Label::new("Click on").size(LabelSize::Small))
371                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
372                    .child(
373                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
374                    ),
375            )
376            .into_any()
377    }
378}