1use crate::{
2 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::{anyhow, Result};
5use editor::{Editor, EditorElement, EditorStyle};
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
8use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
9use settings::Settings;
10use std::time::Duration;
11use std::{env, sync::Arc};
12use theme::ThemeSettings;
13use ui::prelude::*;
14use util::{http::HttpClient, ResultExt};
15
16pub struct OpenAiCompletionProvider {
17 api_key: Option<String>,
18 api_url: String,
19 default_model: OpenAiModel,
20 http_client: Arc<dyn HttpClient>,
21 low_speed_timeout: Option<Duration>,
22 settings_version: usize,
23}
24
25impl OpenAiCompletionProvider {
26 pub fn new(
27 default_model: OpenAiModel,
28 api_url: String,
29 http_client: Arc<dyn HttpClient>,
30 low_speed_timeout: Option<Duration>,
31 settings_version: usize,
32 ) -> Self {
33 Self {
34 api_key: None,
35 api_url,
36 default_model,
37 http_client,
38 low_speed_timeout,
39 settings_version,
40 }
41 }
42
43 pub fn update(
44 &mut self,
45 default_model: OpenAiModel,
46 api_url: String,
47 low_speed_timeout: Option<Duration>,
48 settings_version: usize,
49 ) {
50 self.default_model = default_model;
51 self.api_url = api_url;
52 self.low_speed_timeout = low_speed_timeout;
53 self.settings_version = settings_version;
54 }
55
56 pub fn settings_version(&self) -> usize {
57 self.settings_version
58 }
59
60 pub fn is_authenticated(&self) -> bool {
61 self.api_key.is_some()
62 }
63
64 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
65 if self.is_authenticated() {
66 Task::ready(Ok(()))
67 } else {
68 let api_url = self.api_url.clone();
69 cx.spawn(|mut cx| async move {
70 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
71 api_key
72 } else {
73 let (_, api_key) = cx
74 .update(|cx| cx.read_credentials(&api_url))?
75 .await?
76 .ok_or_else(|| anyhow!("credentials not found"))?;
77 String::from_utf8(api_key)?
78 };
79 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
80 if let CompletionProvider::OpenAi(provider) = provider {
81 provider.api_key = Some(api_key);
82 }
83 })
84 })
85 }
86 }
87
88 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
89 let delete_credentials = cx.delete_credentials(&self.api_url);
90 cx.spawn(|mut cx| async move {
91 delete_credentials.await.log_err();
92 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
93 if let CompletionProvider::OpenAi(provider) = provider {
94 provider.api_key = None;
95 }
96 })
97 })
98 }
99
100 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
101 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
102 .into()
103 }
104
105 pub fn default_model(&self) -> OpenAiModel {
106 self.default_model.clone()
107 }
108
109 pub fn count_tokens(
110 &self,
111 request: LanguageModelRequest,
112 cx: &AppContext,
113 ) -> BoxFuture<'static, Result<usize>> {
114 count_open_ai_tokens(request, cx.background_executor())
115 }
116
117 pub fn complete(
118 &self,
119 request: LanguageModelRequest,
120 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
121 let request = self.to_open_ai_request(request);
122
123 let http_client = self.http_client.clone();
124 let api_key = self.api_key.clone();
125 let api_url = self.api_url.clone();
126 let low_speed_timeout = self.low_speed_timeout;
127 async move {
128 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
129 let request = stream_completion(
130 http_client.as_ref(),
131 &api_url,
132 &api_key,
133 request,
134 low_speed_timeout,
135 );
136 let response = request.await?;
137 let stream = response
138 .filter_map(|response| async move {
139 match response {
140 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
141 Err(error) => Some(Err(error)),
142 }
143 })
144 .boxed();
145 Ok(stream)
146 }
147 .boxed()
148 }
149
150 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
151 let model = match request.model {
152 LanguageModel::ZedDotDev(_) => self.default_model(),
153 LanguageModel::OpenAi(model) => model,
154 };
155
156 Request {
157 model,
158 messages: request
159 .messages
160 .into_iter()
161 .map(|msg| match msg.role {
162 Role::User => RequestMessage::User {
163 content: msg.content,
164 },
165 Role::Assistant => RequestMessage::Assistant {
166 content: Some(msg.content),
167 tool_calls: Vec::new(),
168 },
169 Role::System => RequestMessage::System {
170 content: msg.content,
171 },
172 })
173 .collect(),
174 stream: true,
175 stop: request.stop,
176 temperature: request.temperature,
177 tools: Vec::new(),
178 tool_choice: None,
179 }
180 }
181}
182
183pub fn count_open_ai_tokens(
184 request: LanguageModelRequest,
185 background_executor: &gpui::BackgroundExecutor,
186) -> BoxFuture<'static, Result<usize>> {
187 background_executor
188 .spawn(async move {
189 let messages = request
190 .messages
191 .into_iter()
192 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
193 role: match message.role {
194 Role::User => "user".into(),
195 Role::Assistant => "assistant".into(),
196 Role::System => "system".into(),
197 },
198 content: Some(message.content),
199 name: None,
200 function_call: None,
201 })
202 .collect::<Vec<_>>();
203
204 tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
205 })
206 .boxed()
207}
208
209impl From<Role> for open_ai::Role {
210 fn from(val: Role) -> Self {
211 match val {
212 Role::User => OpenAiRole::User,
213 Role::Assistant => OpenAiRole::Assistant,
214 Role::System => OpenAiRole::System,
215 }
216 }
217}
218
219struct AuthenticationPrompt {
220 api_key: View<Editor>,
221 api_url: String,
222}
223
224impl AuthenticationPrompt {
225 fn new(api_url: String, cx: &mut WindowContext) -> Self {
226 Self {
227 api_key: cx.new_view(|cx| {
228 let mut editor = Editor::single_line(cx);
229 editor.set_placeholder_text(
230 "sk-000000000000000000000000000000000000000000000000",
231 cx,
232 );
233 editor
234 }),
235 api_url,
236 }
237 }
238
239 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
240 let api_key = self.api_key.read(cx).text(cx);
241 if api_key.is_empty() {
242 return;
243 }
244
245 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
246 cx.spawn(|_, mut cx| async move {
247 write_credentials.await?;
248 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
249 if let CompletionProvider::OpenAi(provider) = provider {
250 provider.api_key = Some(api_key);
251 }
252 })
253 })
254 .detach_and_log_err(cx);
255 }
256
257 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
258 let settings = ThemeSettings::get_global(cx);
259 let text_style = TextStyle {
260 color: cx.theme().colors().text,
261 font_family: settings.ui_font.family.clone(),
262 font_features: settings.ui_font.features.clone(),
263 font_size: rems(0.875).into(),
264 font_weight: FontWeight::NORMAL,
265 font_style: FontStyle::Normal,
266 line_height: relative(1.3),
267 background_color: None,
268 underline: None,
269 strikethrough: None,
270 white_space: WhiteSpace::Normal,
271 };
272 EditorElement::new(
273 &self.api_key,
274 EditorStyle {
275 background: cx.theme().colors().editor_background,
276 local_player: cx.theme().players().local(),
277 text: text_style,
278 ..Default::default()
279 },
280 )
281 }
282}
283
284impl Render for AuthenticationPrompt {
285 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
286 const INSTRUCTIONS: [&str; 6] = [
287 "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
288 " - You can create an API key at: platform.openai.com/api-keys",
289 " - Make sure your OpenAI account has credits",
290 " - Having a subscription for another service like GitHub Copilot won't work.",
291 "",
292 "Paste your OpenAI API key below and hit enter to use the assistant:",
293 ];
294
295 v_flex()
296 .p_4()
297 .size_full()
298 .on_action(cx.listener(Self::save_api_key))
299 .children(
300 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
301 )
302 .child(
303 h_flex()
304 .w_full()
305 .my_2()
306 .px_2()
307 .py_1()
308 .bg(cx.theme().colors().editor_background)
309 .rounded_md()
310 .child(self.render_api_key_editor(cx)),
311 )
312 .child(
313 Label::new(
314 "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
315 )
316 .size(LabelSize::Small),
317 )
318 .child(
319 h_flex()
320 .gap_2()
321 .child(Label::new("Click on").size(LabelSize::Small))
322 .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
323 .child(
324 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
325 ),
326 )
327 .into_any()
328 }
329}