open_ai.rs

use crate::assistant_settings::CloudModel;
use crate::{
    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;

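/// Completion provider backed by the OpenAI chat completions API.
///
/// Holds the API key (once authenticated), the target API URL, the selected model,
/// and the HTTP client used to stream responses. A rough usage sketch (the `model`,
/// `api_url`, `http_client`, and `settings_version` values are placeholders supplied
/// by the caller):
///
/// ```ignore
/// let provider = OpenAiCompletionProvider::new(model, api_url, http_client, None, settings_version);
/// // `authenticate` must succeed (env var or stored credentials) before `complete` will work.
/// ```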
pub struct OpenAiCompletionProvider {
    api_key: Option<String>,
    api_url: String,
    model: OpenAiModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
}

impl OpenAiCompletionProvider {
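    /// Creates a provider with no API key; `authenticate` must run before completions can be requested.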
    pub fn new(
        model: OpenAiModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_key: None,
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
        }
    }

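    /// Applies updated settings (model, API URL, timeout, settings version) without
    /// dropping the stored API key.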
    pub fn update(
        &mut self,
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.model = model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
        OpenAiModel::iter()
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    pub fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

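    /// Resolves an API key, preferring the `OPENAI_API_KEY` environment variable and
    /// falling back to credentials stored under the API URL, then saves it on the
    /// global `CompletionProvider`. Resolves immediately if already authenticated.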
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    if let CompletionProvider::OpenAi(provider) = provider {
                        provider.api_key = Some(api_key);
                    }
                })
            })
        }
    }

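    /// Deletes any credentials stored under the API URL and clears the in-memory API key.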
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = None;
                }
            })
        })
    }

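    /// Returns the view that prompts the user to enter an OpenAI API key.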
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    pub fn model(&self) -> OpenAiModel {
        self.model.clone()
    }

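    /// Counts the tokens in `request` on the background executor via `count_open_ai_tokens`.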
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

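    /// Streams a chat completion for `request`, yielding the assistant's reply as
    /// incremental text deltas. Fails with "missing api key" if not authenticated.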
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

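    /// Converts a `LanguageModelRequest` into an OpenAI `Request`, mapping message roles
    /// and falling back to the provider's configured model when the request does not
    /// target an OpenAI model. Streaming is always enabled.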
    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
        let model = match request.model {
            LanguageModel::OpenAi(model) => model,
            _ => self.model(),
        };

        Request {
            model,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => RequestMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => RequestMessage::Assistant {
                        content: Some(msg.content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => RequestMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            stream: true,
            stop: request.stop,
            temperature: request.temperature,
            tools: Vec::new(),
            tool_choice: None,
        }
    }
}

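/// Counts the tokens in `request` with `tiktoken_rs` on the given background executor.
///
/// Anthropic and Claude 3 cloud models have no tiktoken tokenizer, so their messages
/// are counted with the GPT-4 tokenizer as an approximation.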
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
    background_executor
        .spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.content),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            match request.model {
                LanguageModel::Anthropic(_)
                | LanguageModel::Cloud(CloudModel::Claude3Opus)
                | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
                | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
                    // Tiktoken doesn't yet support these models, so we manually use the
                    // same tokenizer as GPT-4.
                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
                }
                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
            }
        })
        .boxed()
}

impl From<Role> for open_ai::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OpenAiRole::User,
            Role::Assistant => OpenAiRole::Assistant,
            Role::System => OpenAiRole::System,
        }
    }
}

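/// View prompting the user to paste an OpenAI API key, stored under the provider's API URL.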
struct AuthenticationPrompt {
    api_key: View<Editor>,
    api_url: String,
}

impl AuthenticationPrompt {
    fn new(api_url: String, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(
                    "sk-000000000000000000000000000000000000000000000000",
                    cx,
                );
                editor
            }),
            api_url,
        }
    }

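    /// Persists the entered key to the credential store (keyed by the API URL) and sets
    /// it on the global `CompletionProvider`. Does nothing if the editor is empty.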
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
        cx.spawn(|_, mut cx| async move {
            write_credentials.await?;
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = Some(api_key);
                }
            })
        })
        .detach_and_log_err(cx);
    }

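    /// Renders the single-line API key editor styled with the current UI font and theme colors.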
    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}

impl Render for AuthenticationPrompt {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 6] = [
            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
            " - You can create an API key at: platform.openai.com/api-keys",
            " - Make sure your OpenAI account has credits.",
            " - Having a subscription for another service like GitHub Copilot won't work.",
            "",
            "Paste your OpenAI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also set the OPENAI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}