open_ai.rs

  1use crate::assistant_settings::CloudModel;
  2use crate::assistant_settings::{AssistantProvider, AssistantSettings};
  3use crate::{
  4    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  5};
  6use anyhow::{anyhow, Result};
  7use editor::{Editor, EditorElement, EditorStyle};
  8use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  9use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 10use http::HttpClient;
 11use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 12use settings::Settings;
 13use std::time::Duration;
 14use std::{env, sync::Arc};
 15use strum::IntoEnumIterator;
 16use theme::ThemeSettings;
 17use ui::prelude::*;
 18use util::ResultExt;
 19
/// Completion provider backed by an OpenAI-compatible chat-completions
/// endpoint reachable at `api_url`.
pub struct OpenAiCompletionProvider {
    // API key sent to the endpoint; `None` until `authenticate` succeeds.
    api_key: Option<String>,
    // Base URL of the OpenAI-compatible API; also used as the key under
    // which credentials are stored/looked up.
    api_url: String,
    // Default model, used when a request doesn't carry an OpenAI model.
    model: OpenAiModel,
    http_client: Arc<dyn HttpClient>,
    // Forwarded to `stream_completion` to abort stalled streams.
    low_speed_timeout: Option<Duration>,
    // Settings generation this provider was last configured from.
    settings_version: usize,
}
 28
impl OpenAiCompletionProvider {
    /// Creates a provider with no API key loaded; `authenticate` must
    /// succeed before `complete` can produce results.
    pub fn new(
        model: OpenAiModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_key: None,
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
        }
    }

    /// Reconfigures the provider from updated settings. Note: the API key
    /// is intentionally left untouched, even if `api_url` changes.
    pub fn update(
        &mut self,
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.model = model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    /// Returns the models selectable in the UI, resolved in priority order:
    /// explicit `available_models` from settings, then the configured custom
    /// model alone, then all built-in (non-custom) variants.
    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
        if let AssistantProvider::OpenAi {
            available_models, ..
        } = &AssistantSettings::get_global(cx).provider
        {
            if !available_models.is_empty() {
                // available_models is set, just return it
                return available_models.clone().into_iter();
            }
        }
        let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
            // available_models is not set but the default model is set to custom, only show custom
            vec![self.model.clone()]
        } else {
            // default case, use all models except custom
            OpenAiModel::iter()
                .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                .collect()
        };
        available_models.into_iter()
    }

    /// Settings generation this provider was built from; callers can compare
    /// it against the current settings to decide whether to rebuild.
    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// True once an API key has been resolved (env var or stored credentials).
    pub fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    /// Resolves the API key and stores it on the global `CompletionProvider`.
    /// The `OPENAI_API_KEY` environment variable takes precedence; otherwise
    /// credentials previously saved under `api_url` are read. Fails if
    /// neither source yields a key.
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    // Stored credentials are (username, password) pairs; only
                    // the password portion holds the key.
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                // Write the key back into the live global provider, not
                // `self`, since `self` may have been replaced meanwhile.
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    if let CompletionProvider::OpenAi(provider) = provider {
                        provider.api_key = Some(api_key);
                    }
                })
            })
        }
    }

    /// Deletes stored credentials for `api_url` (best-effort; failures are
    /// only logged) and clears the key on the global provider.
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = None;
                }
            })
        })
    }

    /// Builds the view shown when no API key is available yet.
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    /// The provider's default model.
    pub fn model(&self) -> OpenAiModel {
        self.model.clone()
    }

    /// Estimates the request's token count on the background executor.
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

    /// Streams a completion for `request`, yielding incremental text chunks.
    /// Errors immediately if no API key has been resolved.
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // Take the last choice's delta text; events with no
                        // choices or no content are silently dropped.
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    /// Translates the generic request into the OpenAI wire format. A
    /// non-OpenAI model on the request is replaced by the provider default.
    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
        let model = match request.model {
            LanguageModel::OpenAi(model) => model,
            _ => self.model(),
        };

        Request {
            model,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => RequestMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => RequestMessage::Assistant {
                        content: Some(msg.content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => RequestMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            stream: true,
            stop: request.stop,
            temperature: request.temperature,
            tools: Vec::new(),
            tool_choice: None,
        }
    }
}
208
209pub fn count_open_ai_tokens(
210    request: LanguageModelRequest,
211    background_executor: &gpui::BackgroundExecutor,
212) -> BoxFuture<'static, Result<usize>> {
213    background_executor
214        .spawn(async move {
215            let messages = request
216                .messages
217                .into_iter()
218                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
219                    role: match message.role {
220                        Role::User => "user".into(),
221                        Role::Assistant => "assistant".into(),
222                        Role::System => "system".into(),
223                    },
224                    content: Some(message.content),
225                    name: None,
226                    function_call: None,
227                })
228                .collect::<Vec<_>>();
229
230            match request.model {
231                LanguageModel::Anthropic(_)
232                | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
233                | LanguageModel::Cloud(CloudModel::Claude3Opus)
234                | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
235                | LanguageModel::Cloud(CloudModel::Claude3Haiku)
236                | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
237                    // Tiktoken doesn't yet support these models, so we manually use the
238                    // same tokenizer as GPT-4.
239                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
240                }
241                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
242            }
243        })
244        .boxed()
245}
246
247impl From<Role> for open_ai::Role {
248    fn from(val: Role) -> Self {
249        match val {
250            Role::User => OpenAiRole::User,
251            Role::Assistant => OpenAiRole::Assistant,
252            Role::System => OpenAiRole::System,
253        }
254    }
255}
256
/// View prompting the user to paste an OpenAI API key when none is set.
struct AuthenticationPrompt {
    // Single-line editor the key is typed into.
    api_key: View<Editor>,
    // Endpoint URL the entered key will be stored against.
    api_url: String,
}
261
impl AuthenticationPrompt {
    /// Builds the prompt with an empty single-line key editor whose
    /// placeholder mimics the `sk-…` key format.
    fn new(api_url: String, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(
                    "sk-000000000000000000000000000000000000000000000000",
                    cx,
                );
                editor
            }),
            api_url,
        }
    }

    /// Confirm handler: persists the typed key as credentials for `api_url`
    /// and installs it on the global provider. Empty input is ignored.
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        // "Bearer" is the credential username; the key is the password.
        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
        cx.spawn(|_, mut cx| async move {
            write_credentials.await?;
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = Some(api_key);
                }
            })
        })
        .detach_and_log_err(cx);
    }

    /// Renders the key editor styled to match the current UI font/theme.
    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}
321
impl Render for AuthenticationPrompt {
    /// Lays out the instructions, the key editor (Enter triggers
    /// `save_api_key` via the `menu::Confirm` action), and the footer hints.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 6] = [
            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
            " - You can create an API key at: platform.openai.com/api-keys",
            " - Make sure your OpenAI account has credits",
            " - Having a subscription for another service like GitHub Copilot won't work.",
            "",
            "Paste your OpenAI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}