1use crate::{
2 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::{anyhow, Result};
5use editor::{Editor, EditorElement, EditorStyle};
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
7use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
8use http::HttpClient;
9use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
10use settings::Settings;
11use std::time::Duration;
12use std::{env, sync::Arc};
13use theme::ThemeSettings;
14use ui::prelude::*;
15use util::ResultExt;
16
17pub struct OpenAiCompletionProvider {
18 api_key: Option<String>,
19 api_url: String,
20 default_model: OpenAiModel,
21 http_client: Arc<dyn HttpClient>,
22 low_speed_timeout: Option<Duration>,
23 settings_version: usize,
24}
25
26impl OpenAiCompletionProvider {
27 pub fn new(
28 default_model: OpenAiModel,
29 api_url: String,
30 http_client: Arc<dyn HttpClient>,
31 low_speed_timeout: Option<Duration>,
32 settings_version: usize,
33 ) -> Self {
34 Self {
35 api_key: None,
36 api_url,
37 default_model,
38 http_client,
39 low_speed_timeout,
40 settings_version,
41 }
42 }
43
44 pub fn update(
45 &mut self,
46 default_model: OpenAiModel,
47 api_url: String,
48 low_speed_timeout: Option<Duration>,
49 settings_version: usize,
50 ) {
51 self.default_model = default_model;
52 self.api_url = api_url;
53 self.low_speed_timeout = low_speed_timeout;
54 self.settings_version = settings_version;
55 }
56
57 pub fn settings_version(&self) -> usize {
58 self.settings_version
59 }
60
61 pub fn is_authenticated(&self) -> bool {
62 self.api_key.is_some()
63 }
64
65 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
66 if self.is_authenticated() {
67 Task::ready(Ok(()))
68 } else {
69 let api_url = self.api_url.clone();
70 cx.spawn(|mut cx| async move {
71 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
72 api_key
73 } else {
74 let (_, api_key) = cx
75 .update(|cx| cx.read_credentials(&api_url))?
76 .await?
77 .ok_or_else(|| anyhow!("credentials not found"))?;
78 String::from_utf8(api_key)?
79 };
80 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
81 if let CompletionProvider::OpenAi(provider) = provider {
82 provider.api_key = Some(api_key);
83 }
84 })
85 })
86 }
87 }
88
89 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
90 let delete_credentials = cx.delete_credentials(&self.api_url);
91 cx.spawn(|mut cx| async move {
92 delete_credentials.await.log_err();
93 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
94 if let CompletionProvider::OpenAi(provider) = provider {
95 provider.api_key = None;
96 }
97 })
98 })
99 }
100
101 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
102 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
103 .into()
104 }
105
106 pub fn default_model(&self) -> OpenAiModel {
107 self.default_model.clone()
108 }
109
110 pub fn count_tokens(
111 &self,
112 request: LanguageModelRequest,
113 cx: &AppContext,
114 ) -> BoxFuture<'static, Result<usize>> {
115 count_open_ai_tokens(request, cx.background_executor())
116 }
117
118 pub fn complete(
119 &self,
120 request: LanguageModelRequest,
121 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
122 let request = self.to_open_ai_request(request);
123
124 let http_client = self.http_client.clone();
125 let api_key = self.api_key.clone();
126 let api_url = self.api_url.clone();
127 let low_speed_timeout = self.low_speed_timeout;
128 async move {
129 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
130 let request = stream_completion(
131 http_client.as_ref(),
132 &api_url,
133 &api_key,
134 request,
135 low_speed_timeout,
136 );
137 let response = request.await?;
138 let stream = response
139 .filter_map(|response| async move {
140 match response {
141 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
142 Err(error) => Some(Err(error)),
143 }
144 })
145 .boxed();
146 Ok(stream)
147 }
148 .boxed()
149 }
150
151 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
152 let model = match request.model {
153 LanguageModel::ZedDotDev(_) => self.default_model(),
154 LanguageModel::OpenAi(model) => model,
155 };
156
157 Request {
158 model,
159 messages: request
160 .messages
161 .into_iter()
162 .map(|msg| match msg.role {
163 Role::User => RequestMessage::User {
164 content: msg.content,
165 },
166 Role::Assistant => RequestMessage::Assistant {
167 content: Some(msg.content),
168 tool_calls: Vec::new(),
169 },
170 Role::System => RequestMessage::System {
171 content: msg.content,
172 },
173 })
174 .collect(),
175 stream: true,
176 stop: request.stop,
177 temperature: request.temperature,
178 tools: Vec::new(),
179 tool_choice: None,
180 }
181 }
182}
183
184pub fn count_open_ai_tokens(
185 request: LanguageModelRequest,
186 background_executor: &gpui::BackgroundExecutor,
187) -> BoxFuture<'static, Result<usize>> {
188 background_executor
189 .spawn(async move {
190 let messages = request
191 .messages
192 .into_iter()
193 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
194 role: match message.role {
195 Role::User => "user".into(),
196 Role::Assistant => "assistant".into(),
197 Role::System => "system".into(),
198 },
199 content: Some(message.content),
200 name: None,
201 function_call: None,
202 })
203 .collect::<Vec<_>>();
204
205 tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
206 })
207 .boxed()
208}
209
210impl From<Role> for open_ai::Role {
211 fn from(val: Role) -> Self {
212 match val {
213 Role::User => OpenAiRole::User,
214 Role::Assistant => OpenAiRole::Assistant,
215 Role::System => OpenAiRole::System,
216 }
217 }
218}
219
220struct AuthenticationPrompt {
221 api_key: View<Editor>,
222 api_url: String,
223}
224
225impl AuthenticationPrompt {
226 fn new(api_url: String, cx: &mut WindowContext) -> Self {
227 Self {
228 api_key: cx.new_view(|cx| {
229 let mut editor = Editor::single_line(cx);
230 editor.set_placeholder_text(
231 "sk-000000000000000000000000000000000000000000000000",
232 cx,
233 );
234 editor
235 }),
236 api_url,
237 }
238 }
239
240 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
241 let api_key = self.api_key.read(cx).text(cx);
242 if api_key.is_empty() {
243 return;
244 }
245
246 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
247 cx.spawn(|_, mut cx| async move {
248 write_credentials.await?;
249 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
250 if let CompletionProvider::OpenAi(provider) = provider {
251 provider.api_key = Some(api_key);
252 }
253 })
254 })
255 .detach_and_log_err(cx);
256 }
257
258 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
259 let settings = ThemeSettings::get_global(cx);
260 let text_style = TextStyle {
261 color: cx.theme().colors().text,
262 font_family: settings.ui_font.family.clone(),
263 font_features: settings.ui_font.features.clone(),
264 font_size: rems(0.875).into(),
265 font_weight: FontWeight::NORMAL,
266 font_style: FontStyle::Normal,
267 line_height: relative(1.3),
268 background_color: None,
269 underline: None,
270 strikethrough: None,
271 white_space: WhiteSpace::Normal,
272 };
273 EditorElement::new(
274 &self.api_key,
275 EditorStyle {
276 background: cx.theme().colors().editor_background,
277 local_player: cx.theme().players().local(),
278 text: text_style,
279 ..Default::default()
280 },
281 )
282 }
283}
284
285impl Render for AuthenticationPrompt {
286 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
287 const INSTRUCTIONS: [&str; 6] = [
288 "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
289 " - You can create an API key at: platform.openai.com/api-keys",
290 " - Make sure your OpenAI account has credits",
291 " - Having a subscription for another service like GitHub Copilot won't work.",
292 "",
293 "Paste your OpenAI API key below and hit enter to use the assistant:",
294 ];
295
296 v_flex()
297 .p_4()
298 .size_full()
299 .on_action(cx.listener(Self::save_api_key))
300 .children(
301 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
302 )
303 .child(
304 h_flex()
305 .w_full()
306 .my_2()
307 .px_2()
308 .py_1()
309 .bg(cx.theme().colors().editor_background)
310 .rounded_md()
311 .child(self.render_api_key_editor(cx)),
312 )
313 .child(
314 Label::new(
315 "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
316 )
317 .size(LabelSize::Small),
318 )
319 .child(
320 h_flex()
321 .gap_2()
322 .child(Label::new("Click on").size(LabelSize::Small))
323 .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
324 .child(
325 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
326 ),
327 )
328 .into_any()
329 }
330}