open_ai.rs

  1use crate::{
  2    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  3};
  4use anyhow::{anyhow, Result};
  5use editor::{Editor, EditorElement, EditorStyle};
  6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
  8use http::HttpClient;
  9use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
 10use settings::Settings;
 11use std::time::Duration;
 12use std::{env, sync::Arc};
 13use theme::ThemeSettings;
 14use ui::prelude::*;
 15use util::ResultExt;
 16
 17pub struct OpenAiCompletionProvider {
 18    api_key: Option<String>,
 19    api_url: String,
 20    default_model: OpenAiModel,
 21    http_client: Arc<dyn HttpClient>,
 22    low_speed_timeout: Option<Duration>,
 23    settings_version: usize,
 24}
 25
 26impl OpenAiCompletionProvider {
 27    pub fn new(
 28        default_model: OpenAiModel,
 29        api_url: String,
 30        http_client: Arc<dyn HttpClient>,
 31        low_speed_timeout: Option<Duration>,
 32        settings_version: usize,
 33    ) -> Self {
 34        Self {
 35            api_key: None,
 36            api_url,
 37            default_model,
 38            http_client,
 39            low_speed_timeout,
 40            settings_version,
 41        }
 42    }
 43
 44    pub fn update(
 45        &mut self,
 46        default_model: OpenAiModel,
 47        api_url: String,
 48        low_speed_timeout: Option<Duration>,
 49        settings_version: usize,
 50    ) {
 51        self.default_model = default_model;
 52        self.api_url = api_url;
 53        self.low_speed_timeout = low_speed_timeout;
 54        self.settings_version = settings_version;
 55    }
 56
 57    pub fn settings_version(&self) -> usize {
 58        self.settings_version
 59    }
 60
 61    pub fn is_authenticated(&self) -> bool {
 62        self.api_key.is_some()
 63    }
 64
 65    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 66        if self.is_authenticated() {
 67            Task::ready(Ok(()))
 68        } else {
 69            let api_url = self.api_url.clone();
 70            cx.spawn(|mut cx| async move {
 71                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
 72                    api_key
 73                } else {
 74                    let (_, api_key) = cx
 75                        .update(|cx| cx.read_credentials(&api_url))?
 76                        .await?
 77                        .ok_or_else(|| anyhow!("credentials not found"))?;
 78                    String::from_utf8(api_key)?
 79                };
 80                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 81                    if let CompletionProvider::OpenAi(provider) = provider {
 82                        provider.api_key = Some(api_key);
 83                    }
 84                })
 85            })
 86        }
 87    }
 88
 89    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 90        let delete_credentials = cx.delete_credentials(&self.api_url);
 91        cx.spawn(|mut cx| async move {
 92            delete_credentials.await.log_err();
 93            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 94                if let CompletionProvider::OpenAi(provider) = provider {
 95                    provider.api_key = None;
 96                }
 97            })
 98        })
 99    }
100
101    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
102        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
103            .into()
104    }
105
106    pub fn default_model(&self) -> OpenAiModel {
107        self.default_model.clone()
108    }
109
110    pub fn count_tokens(
111        &self,
112        request: LanguageModelRequest,
113        cx: &AppContext,
114    ) -> BoxFuture<'static, Result<usize>> {
115        count_open_ai_tokens(request, cx.background_executor())
116    }
117
118    pub fn complete(
119        &self,
120        request: LanguageModelRequest,
121    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
122        let request = self.to_open_ai_request(request);
123
124        let http_client = self.http_client.clone();
125        let api_key = self.api_key.clone();
126        let api_url = self.api_url.clone();
127        let low_speed_timeout = self.low_speed_timeout;
128        async move {
129            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
130            let request = stream_completion(
131                http_client.as_ref(),
132                &api_url,
133                &api_key,
134                request,
135                low_speed_timeout,
136            );
137            let response = request.await?;
138            let stream = response
139                .filter_map(|response| async move {
140                    match response {
141                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
142                        Err(error) => Some(Err(error)),
143                    }
144                })
145                .boxed();
146            Ok(stream)
147        }
148        .boxed()
149    }
150
151    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
152        let model = match request.model {
153            LanguageModel::ZedDotDev(_) => self.default_model(),
154            LanguageModel::OpenAi(model) => model,
155        };
156
157        Request {
158            model,
159            messages: request
160                .messages
161                .into_iter()
162                .map(|msg| match msg.role {
163                    Role::User => RequestMessage::User {
164                        content: msg.content,
165                    },
166                    Role::Assistant => RequestMessage::Assistant {
167                        content: Some(msg.content),
168                        tool_calls: Vec::new(),
169                    },
170                    Role::System => RequestMessage::System {
171                        content: msg.content,
172                    },
173                })
174                .collect(),
175            stream: true,
176            stop: request.stop,
177            temperature: request.temperature,
178            tools: Vec::new(),
179            tool_choice: None,
180        }
181    }
182}
183
184pub fn count_open_ai_tokens(
185    request: LanguageModelRequest,
186    background_executor: &gpui::BackgroundExecutor,
187) -> BoxFuture<'static, Result<usize>> {
188    background_executor
189        .spawn(async move {
190            let messages = request
191                .messages
192                .into_iter()
193                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
194                    role: match message.role {
195                        Role::User => "user".into(),
196                        Role::Assistant => "assistant".into(),
197                        Role::System => "system".into(),
198                    },
199                    content: Some(message.content),
200                    name: None,
201                    function_call: None,
202                })
203                .collect::<Vec<_>>();
204
205            tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
206        })
207        .boxed()
208}
209
210impl From<Role> for open_ai::Role {
211    fn from(val: Role) -> Self {
212        match val {
213            Role::User => OpenAiRole::User,
214            Role::Assistant => OpenAiRole::Assistant,
215            Role::System => OpenAiRole::System,
216        }
217    }
218}
219
220struct AuthenticationPrompt {
221    api_key: View<Editor>,
222    api_url: String,
223}
224
225impl AuthenticationPrompt {
226    fn new(api_url: String, cx: &mut WindowContext) -> Self {
227        Self {
228            api_key: cx.new_view(|cx| {
229                let mut editor = Editor::single_line(cx);
230                editor.set_placeholder_text(
231                    "sk-000000000000000000000000000000000000000000000000",
232                    cx,
233                );
234                editor
235            }),
236            api_url,
237        }
238    }
239
240    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
241        let api_key = self.api_key.read(cx).text(cx);
242        if api_key.is_empty() {
243            return;
244        }
245
246        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
247        cx.spawn(|_, mut cx| async move {
248            write_credentials.await?;
249            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
250                if let CompletionProvider::OpenAi(provider) = provider {
251                    provider.api_key = Some(api_key);
252                }
253            })
254        })
255        .detach_and_log_err(cx);
256    }
257
258    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
259        let settings = ThemeSettings::get_global(cx);
260        let text_style = TextStyle {
261            color: cx.theme().colors().text,
262            font_family: settings.ui_font.family.clone(),
263            font_features: settings.ui_font.features.clone(),
264            font_size: rems(0.875).into(),
265            font_weight: FontWeight::NORMAL,
266            font_style: FontStyle::Normal,
267            line_height: relative(1.3),
268            background_color: None,
269            underline: None,
270            strikethrough: None,
271            white_space: WhiteSpace::Normal,
272        };
273        EditorElement::new(
274            &self.api_key,
275            EditorStyle {
276                background: cx.theme().colors().editor_background,
277                local_player: cx.theme().players().local(),
278                text: text_style,
279                ..Default::default()
280            },
281        )
282    }
283}
284
285impl Render for AuthenticationPrompt {
286    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
287        const INSTRUCTIONS: [&str; 6] = [
288            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
289            " - You can create an API key at: platform.openai.com/api-keys",
290            " - Make sure your OpenAI account has credits",
291            " - Having a subscription for another service like GitHub Copilot won't work.",
292            "",
293            "Paste your OpenAI API key below and hit enter to use the assistant:",
294        ];
295
296        v_flex()
297            .p_4()
298            .size_full()
299            .on_action(cx.listener(Self::save_api_key))
300            .children(
301                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
302            )
303            .child(
304                h_flex()
305                    .w_full()
306                    .my_2()
307                    .px_2()
308                    .py_1()
309                    .bg(cx.theme().colors().editor_background)
310                    .rounded_md()
311                    .child(self.render_api_key_editor(cx)),
312            )
313            .child(
314                Label::new(
315                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
316                )
317                .size(LabelSize::Small),
318            )
319            .child(
320                h_flex()
321                    .gap_2()
322                    .child(Label::new("Click on").size(LabelSize::Small))
323                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
324                    .child(
325                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
326                    ),
327            )
328            .into_any()
329    }
330}