open_ai.rs

  1use crate::{
  2    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
  3};
  4use anyhow::{anyhow, Result};
  5use editor::{Editor, EditorElement, EditorStyle};
  6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
  7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
  8use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
  9use settings::Settings;
 10use std::time::Duration;
 11use std::{env, sync::Arc};
 12use theme::ThemeSettings;
 13use ui::prelude::*;
 14use util::{http::HttpClient, ResultExt};
 15
 16pub struct OpenAiCompletionProvider {
 17    api_key: Option<String>,
 18    api_url: String,
 19    default_model: OpenAiModel,
 20    http_client: Arc<dyn HttpClient>,
 21    low_speed_timeout: Option<Duration>,
 22    settings_version: usize,
 23}
 24
 25impl OpenAiCompletionProvider {
 26    pub fn new(
 27        default_model: OpenAiModel,
 28        api_url: String,
 29        http_client: Arc<dyn HttpClient>,
 30        low_speed_timeout: Option<Duration>,
 31        settings_version: usize,
 32    ) -> Self {
 33        Self {
 34            api_key: None,
 35            api_url,
 36            default_model,
 37            http_client,
 38            low_speed_timeout,
 39            settings_version,
 40        }
 41    }
 42
 43    pub fn update(
 44        &mut self,
 45        default_model: OpenAiModel,
 46        api_url: String,
 47        low_speed_timeout: Option<Duration>,
 48        settings_version: usize,
 49    ) {
 50        self.default_model = default_model;
 51        self.api_url = api_url;
 52        self.low_speed_timeout = low_speed_timeout;
 53        self.settings_version = settings_version;
 54    }
 55
 56    pub fn settings_version(&self) -> usize {
 57        self.settings_version
 58    }
 59
 60    pub fn is_authenticated(&self) -> bool {
 61        self.api_key.is_some()
 62    }
 63
 64    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 65        if self.is_authenticated() {
 66            Task::ready(Ok(()))
 67        } else {
 68            let api_url = self.api_url.clone();
 69            cx.spawn(|mut cx| async move {
 70                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
 71                    api_key
 72                } else {
 73                    let (_, api_key) = cx
 74                        .update(|cx| cx.read_credentials(&api_url))?
 75                        .await?
 76                        .ok_or_else(|| anyhow!("credentials not found"))?;
 77                    String::from_utf8(api_key)?
 78                };
 79                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 80                    if let CompletionProvider::OpenAi(provider) = provider {
 81                        provider.api_key = Some(api_key);
 82                    }
 83                })
 84            })
 85        }
 86    }
 87
 88    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 89        let delete_credentials = cx.delete_credentials(&self.api_url);
 90        cx.spawn(|mut cx| async move {
 91            delete_credentials.await.log_err();
 92            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
 93                if let CompletionProvider::OpenAi(provider) = provider {
 94                    provider.api_key = None;
 95                }
 96            })
 97        })
 98    }
 99
100    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
101        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
102            .into()
103    }
104
105    pub fn default_model(&self) -> OpenAiModel {
106        self.default_model.clone()
107    }
108
109    pub fn count_tokens(
110        &self,
111        request: LanguageModelRequest,
112        cx: &AppContext,
113    ) -> BoxFuture<'static, Result<usize>> {
114        count_open_ai_tokens(request, cx.background_executor())
115    }
116
117    pub fn complete(
118        &self,
119        request: LanguageModelRequest,
120    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
121        let request = self.to_open_ai_request(request);
122
123        let http_client = self.http_client.clone();
124        let api_key = self.api_key.clone();
125        let api_url = self.api_url.clone();
126        let low_speed_timeout = self.low_speed_timeout;
127        async move {
128            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
129            let request = stream_completion(
130                http_client.as_ref(),
131                &api_url,
132                &api_key,
133                request,
134                low_speed_timeout,
135            );
136            let response = request.await?;
137            let stream = response
138                .filter_map(|response| async move {
139                    match response {
140                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
141                        Err(error) => Some(Err(error)),
142                    }
143                })
144                .boxed();
145            Ok(stream)
146        }
147        .boxed()
148    }
149
150    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
151        let model = match request.model {
152            LanguageModel::ZedDotDev(_) => self.default_model(),
153            LanguageModel::OpenAi(model) => model,
154        };
155
156        Request {
157            model,
158            messages: request
159                .messages
160                .into_iter()
161                .map(|msg| match msg.role {
162                    Role::User => RequestMessage::User {
163                        content: msg.content,
164                    },
165                    Role::Assistant => RequestMessage::Assistant {
166                        content: Some(msg.content),
167                        tool_calls: Vec::new(),
168                    },
169                    Role::System => RequestMessage::System {
170                        content: msg.content,
171                    },
172                })
173                .collect(),
174            stream: true,
175            stop: request.stop,
176            temperature: request.temperature,
177            tools: Vec::new(),
178            tool_choice: None,
179        }
180    }
181}
182
183pub fn count_open_ai_tokens(
184    request: LanguageModelRequest,
185    background_executor: &gpui::BackgroundExecutor,
186) -> BoxFuture<'static, Result<usize>> {
187    background_executor
188        .spawn(async move {
189            let messages = request
190                .messages
191                .into_iter()
192                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
193                    role: match message.role {
194                        Role::User => "user".into(),
195                        Role::Assistant => "assistant".into(),
196                        Role::System => "system".into(),
197                    },
198                    content: Some(message.content),
199                    name: None,
200                    function_call: None,
201                })
202                .collect::<Vec<_>>();
203
204            tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
205        })
206        .boxed()
207}
208
209impl From<Role> for open_ai::Role {
210    fn from(val: Role) -> Self {
211        match val {
212            Role::User => OpenAiRole::User,
213            Role::Assistant => OpenAiRole::Assistant,
214            Role::System => OpenAiRole::System,
215        }
216    }
217}
218
219struct AuthenticationPrompt {
220    api_key: View<Editor>,
221    api_url: String,
222}
223
224impl AuthenticationPrompt {
225    fn new(api_url: String, cx: &mut WindowContext) -> Self {
226        Self {
227            api_key: cx.new_view(|cx| {
228                let mut editor = Editor::single_line(cx);
229                editor.set_placeholder_text(
230                    "sk-000000000000000000000000000000000000000000000000",
231                    cx,
232                );
233                editor
234            }),
235            api_url,
236        }
237    }
238
239    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
240        let api_key = self.api_key.read(cx).text(cx);
241        if api_key.is_empty() {
242            return;
243        }
244
245        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
246        cx.spawn(|_, mut cx| async move {
247            write_credentials.await?;
248            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
249                if let CompletionProvider::OpenAi(provider) = provider {
250                    provider.api_key = Some(api_key);
251                }
252            })
253        })
254        .detach_and_log_err(cx);
255    }
256
257    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
258        let settings = ThemeSettings::get_global(cx);
259        let text_style = TextStyle {
260            color: cx.theme().colors().text,
261            font_family: settings.ui_font.family.clone(),
262            font_features: settings.ui_font.features.clone(),
263            font_size: rems(0.875).into(),
264            font_weight: FontWeight::NORMAL,
265            font_style: FontStyle::Normal,
266            line_height: relative(1.3),
267            background_color: None,
268            underline: None,
269            strikethrough: None,
270            white_space: WhiteSpace::Normal,
271        };
272        EditorElement::new(
273            &self.api_key,
274            EditorStyle {
275                background: cx.theme().colors().editor_background,
276                local_player: cx.theme().players().local(),
277                text: text_style,
278                ..Default::default()
279            },
280        )
281    }
282}
283
284impl Render for AuthenticationPrompt {
285    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
286        const INSTRUCTIONS: [&str; 6] = [
287            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
288            " - You can create an API key at: platform.openai.com/api-keys",
289            " - Make sure your OpenAI account has credits",
290            " - Having a subscription for another service like GitHub Copilot won't work.",
291            "",
292            "Paste your OpenAI API key below and hit enter to use the assistant:",
293        ];
294
295        v_flex()
296            .p_4()
297            .size_full()
298            .on_action(cx.listener(Self::save_api_key))
299            .children(
300                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
301            )
302            .child(
303                h_flex()
304                    .w_full()
305                    .my_2()
306                    .px_2()
307                    .py_1()
308                    .bg(cx.theme().colors().editor_background)
309                    .rounded_md()
310                    .child(self.render_api_key_editor(cx)),
311            )
312            .child(
313                Label::new(
314                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
315                )
316                .size(LabelSize::Small),
317            )
318            .child(
319                h_flex()
320                    .gap_2()
321                    .child(Label::new("Click on").size(LabelSize::Small))
322                    .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
323                    .child(
324                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
325                    ),
326            )
327            .into_any()
328    }
329}