open_ai.rs

  1use crate::CompletionProvider;
  2use crate::LanguageModelCompletionProvider;
  3use anyhow::{anyhow, Result};
  4use editor::{Editor, EditorElement, EditorStyle};
  5use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  6use gpui::{AnyView, AppContext, Task, TextStyle, View};
  7use http::HttpClient;
  8use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
  9use open_ai::Model as OpenAiModel;
 10use open_ai::{stream_completion, Request, RequestMessage};
 11use settings::Settings;
 12use std::time::Duration;
 13use std::{env, sync::Arc};
 14use strum::IntoEnumIterator;
 15use theme::ThemeSettings;
 16use ui::prelude::*;
 17use util::ResultExt;
 18
 19pub struct OpenAiCompletionProvider {
 20    api_key: Option<String>,
 21    api_url: String,
 22    model: OpenAiModel,
 23    http_client: Arc<dyn HttpClient>,
 24    low_speed_timeout: Option<Duration>,
 25    settings_version: usize,
 26    available_models_from_settings: Vec<OpenAiModel>,
 27}
 28
 29impl OpenAiCompletionProvider {
 30    pub fn new(
 31        model: OpenAiModel,
 32        api_url: String,
 33        http_client: Arc<dyn HttpClient>,
 34        low_speed_timeout: Option<Duration>,
 35        settings_version: usize,
 36        available_models_from_settings: Vec<OpenAiModel>,
 37    ) -> Self {
 38        Self {
 39            api_key: None,
 40            api_url,
 41            model,
 42            http_client,
 43            low_speed_timeout,
 44            settings_version,
 45            available_models_from_settings,
 46        }
 47    }
 48
 49    pub fn update(
 50        &mut self,
 51        model: OpenAiModel,
 52        api_url: String,
 53        low_speed_timeout: Option<Duration>,
 54        settings_version: usize,
 55    ) {
 56        self.model = model;
 57        self.api_url = api_url;
 58        self.low_speed_timeout = low_speed_timeout;
 59        self.settings_version = settings_version;
 60    }
 61
 62    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
 63        let model = match request.model {
 64            LanguageModel::OpenAi(model) => model,
 65            _ => self.model.clone(),
 66        };
 67
 68        Request {
 69            model,
 70            messages: request
 71                .messages
 72                .into_iter()
 73                .map(|msg| match msg.role {
 74                    Role::User => RequestMessage::User {
 75                        content: msg.content,
 76                    },
 77                    Role::Assistant => RequestMessage::Assistant {
 78                        content: Some(msg.content),
 79                        tool_calls: Vec::new(),
 80                    },
 81                    Role::System => RequestMessage::System {
 82                        content: msg.content,
 83                    },
 84                })
 85                .collect(),
 86            stream: true,
 87            stop: request.stop,
 88            temperature: request.temperature,
 89            tools: Vec::new(),
 90            tool_choice: None,
 91        }
 92    }
 93}
 94
 95impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
 96    fn available_models(&self) -> Vec<LanguageModel> {
 97        if self.available_models_from_settings.is_empty() {
 98            let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
 99                vec![self.model.clone()]
100            } else {
101                OpenAiModel::iter()
102                    .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
103                    .collect()
104            };
105            available_models
106                .into_iter()
107                .map(LanguageModel::OpenAi)
108                .collect()
109        } else {
110            self.available_models_from_settings
111                .iter()
112                .cloned()
113                .map(LanguageModel::OpenAi)
114                .collect()
115        }
116    }
117
118    fn settings_version(&self) -> usize {
119        self.settings_version
120    }
121
122    fn is_authenticated(&self) -> bool {
123        self.api_key.is_some()
124    }
125
126    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
127        if self.is_authenticated() {
128            Task::ready(Ok(()))
129        } else {
130            let api_url = self.api_url.clone();
131            cx.spawn(|mut cx| async move {
132                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
133                    api_key
134                } else {
135                    let (_, api_key) = cx
136                        .update(|cx| cx.read_credentials(&api_url))?
137                        .await?
138                        .ok_or_else(|| anyhow!("credentials not found"))?;
139                    String::from_utf8(api_key)?
140                };
141                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
142                    provider.update_current_as::<_, Self>(|provider| {
143                        provider.api_key = Some(api_key);
144                    });
145                })
146            })
147        }
148    }
149
150    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
151        let delete_credentials = cx.delete_credentials(&self.api_url);
152        cx.spawn(|mut cx| async move {
153            delete_credentials.await.log_err();
154            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
155                provider.update_current_as::<_, Self>(|provider| {
156                    provider.api_key = None;
157                });
158            })
159        })
160    }
161
162    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
163        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
164            .into()
165    }
166
167    fn model(&self) -> LanguageModel {
168        LanguageModel::OpenAi(self.model.clone())
169    }
170
171    fn count_tokens(
172        &self,
173        request: LanguageModelRequest,
174        cx: &AppContext,
175    ) -> BoxFuture<'static, Result<usize>> {
176        count_open_ai_tokens(request, cx.background_executor())
177    }
178
179    fn stream_completion(
180        &self,
181        request: LanguageModelRequest,
182    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
183        let request = self.to_open_ai_request(request);
184
185        let http_client = self.http_client.clone();
186        let api_key = self.api_key.clone();
187        let api_url = self.api_url.clone();
188        let low_speed_timeout = self.low_speed_timeout;
189        async move {
190            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
191            let request = stream_completion(
192                http_client.as_ref(),
193                &api_url,
194                &api_key,
195                request,
196                low_speed_timeout,
197            );
198            let response = request.await?;
199            let stream = response
200                .filter_map(|response| async move {
201                    match response {
202                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
203                        Err(error) => Some(Err(error)),
204                    }
205                })
206                .boxed();
207            Ok(stream)
208        }
209        .boxed()
210    }
211
212    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
213        self
214    }
215}
216
217pub fn count_open_ai_tokens(
218    request: LanguageModelRequest,
219    background_executor: &gpui::BackgroundExecutor,
220) -> BoxFuture<'static, Result<usize>> {
221    background_executor
222        .spawn(async move {
223            let messages = request
224                .messages
225                .into_iter()
226                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
227                    role: match message.role {
228                        Role::User => "user".into(),
229                        Role::Assistant => "assistant".into(),
230                        Role::System => "system".into(),
231                    },
232                    content: Some(message.content),
233                    name: None,
234                    function_call: None,
235                })
236                .collect::<Vec<_>>();
237
238            match request.model {
239                LanguageModel::Anthropic(_)
240                | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
241                | LanguageModel::Cloud(CloudModel::Claude3Opus)
242                | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
243                | LanguageModel::Cloud(CloudModel::Claude3Haiku)
244                | LanguageModel::Cloud(CloudModel::Custom { .. })
245                | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
246                    // Tiktoken doesn't yet support these models, so we manually use the
247                    // same tokenizer as GPT-4.
248                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
249                }
250                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
251            }
252        })
253        .boxed()
254}
255
256struct AuthenticationPrompt {
257    api_key: View<Editor>,
258    api_url: String,
259}
260
261impl AuthenticationPrompt {
262    fn new(api_url: String, cx: &mut WindowContext) -> Self {
263        Self {
264            api_key: cx.new_view(|cx| {
265                let mut editor = Editor::single_line(cx);
266                editor.set_placeholder_text(
267                    "sk-000000000000000000000000000000000000000000000000",
268                    cx,
269                );
270                editor
271            }),
272            api_url,
273        }
274    }
275
276    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
277        let api_key = self.api_key.read(cx).text(cx);
278        if api_key.is_empty() {
279            return;
280        }
281
282        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
283        cx.spawn(|_, mut cx| async move {
284            write_credentials.await?;
285            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
286                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
287                    provider.api_key = Some(api_key);
288                });
289            })
290        })
291        .detach_and_log_err(cx);
292    }
293
294    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
295        let settings = ThemeSettings::get_global(cx);
296        let text_style = TextStyle {
297            color: cx.theme().colors().text,
298            font_family: settings.ui_font.family.clone(),
299            font_features: settings.ui_font.features.clone(),
300            font_size: rems(0.875).into(),
301            font_weight: settings.ui_font.weight,
302            line_height: relative(1.3),
303            ..Default::default()
304        };
305        EditorElement::new(
306            &self.api_key,
307            EditorStyle {
308                background: cx.theme().colors().editor_background,
309                local_player: cx.theme().players().local(),
310                text: text_style,
311                ..Default::default()
312            },
313        )
314    }
315}
316
317impl Render for AuthenticationPrompt {
318    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
319        const INSTRUCTIONS: [&str; 6] = [
320            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
321            " - You can create an API key at: platform.openai.com/api-keys",
322            " - Make sure your OpenAI account has credits",
323            " - Having a subscription for another service like GitHub Copilot won't work.",
324            "",
325            "Paste your OpenAI API key below and hit enter to use the assistant:",
326        ];
327
328        v_flex()
329            .p_4()
330            .size_full()
331            .on_action(cx.listener(Self::save_api_key))
332            .children(
333                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
334            )
335            .child(
336                h_flex()
337                    .w_full()
338                    .my_2()
339                    .px_2()
340                    .py_1()
341                    .bg(cx.theme().colors().editor_background)
342                    .rounded_md()
343                    .child(self.render_api_key_editor(cx)),
344            )
345            .child(
346                Label::new(
347                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
348                )
349                .size(LabelSize::Small),
350            )
351            .child(
352                h_flex()
353                    .gap_2()
354                    .child(Label::new("Click on").size(LabelSize::Small))
355                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
356                    .child(
357                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
358                    ),
359            )
360            .into_any()
361    }
362}