1use crate::CompletionProvider;
2use crate::LanguageModelCompletionProvider;
3use anyhow::{anyhow, Result};
4use editor::{Editor, EditorElement, EditorStyle};
5use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
6use gpui::{AnyView, AppContext, Task, TextStyle, View};
7use http::HttpClient;
8use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
9use open_ai::Model as OpenAiModel;
10use open_ai::{stream_completion, Request, RequestMessage};
11use settings::Settings;
12use std::time::Duration;
13use std::{env, sync::Arc};
14use strum::IntoEnumIterator;
15use theme::ThemeSettings;
16use ui::prelude::*;
17use util::ResultExt;
18
/// Completion provider backed by the OpenAI chat-completions API.
/// Holds connection configuration plus the credential state; the API key is
/// populated lazily by `authenticate` (env var or stored credentials).
pub struct OpenAiCompletionProvider {
    // `None` until `authenticate` succeeds; completions fail without it.
    api_key: Option<String>,
    api_url: String,
    // Default model, used when a request does not name an OpenAI model itself.
    model: OpenAiModel,
    http_client: Arc<dyn HttpClient>,
    // Optional stall timeout forwarded to `stream_completion`.
    low_speed_timeout: Option<Duration>,
    // Monotonic settings generation, used to detect stale configuration.
    settings_version: usize,
    // When non-empty, overrides the built-in model list in `available_models`.
    available_models_from_settings: Vec<OpenAiModel>,
}
28
29impl OpenAiCompletionProvider {
30 pub fn new(
31 model: OpenAiModel,
32 api_url: String,
33 http_client: Arc<dyn HttpClient>,
34 low_speed_timeout: Option<Duration>,
35 settings_version: usize,
36 available_models_from_settings: Vec<OpenAiModel>,
37 ) -> Self {
38 Self {
39 api_key: None,
40 api_url,
41 model,
42 http_client,
43 low_speed_timeout,
44 settings_version,
45 available_models_from_settings,
46 }
47 }
48
49 pub fn update(
50 &mut self,
51 model: OpenAiModel,
52 api_url: String,
53 low_speed_timeout: Option<Duration>,
54 settings_version: usize,
55 ) {
56 self.model = model;
57 self.api_url = api_url;
58 self.low_speed_timeout = low_speed_timeout;
59 self.settings_version = settings_version;
60 }
61
62 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
63 let model = match request.model {
64 LanguageModel::OpenAi(model) => model,
65 _ => self.model.clone(),
66 };
67
68 Request {
69 model,
70 messages: request
71 .messages
72 .into_iter()
73 .map(|msg| match msg.role {
74 Role::User => RequestMessage::User {
75 content: msg.content,
76 },
77 Role::Assistant => RequestMessage::Assistant {
78 content: Some(msg.content),
79 tool_calls: Vec::new(),
80 },
81 Role::System => RequestMessage::System {
82 content: msg.content,
83 },
84 })
85 .collect(),
86 stream: true,
87 stop: request.stop,
88 temperature: request.temperature,
89 tools: Vec::new(),
90 tool_choice: None,
91 }
92 }
93}
94
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
    /// Lists the models to offer in the UI. Settings-provided models take
    /// precedence; otherwise either the single configured custom model or all
    /// built-in (non-custom) OpenAI models are returned.
    fn available_models(&self) -> Vec<LanguageModel> {
        if self.available_models_from_settings.is_empty() {
            // With a custom model configured, expose only that model; the
            // custom enum variant is otherwise excluded from the iteration.
            let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
                vec![self.model.clone()]
            } else {
                OpenAiModel::iter()
                    .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                    .collect()
            };
            available_models
                .into_iter()
                .map(LanguageModel::OpenAi)
                .collect()
        } else {
            self.available_models_from_settings
                .iter()
                .cloned()
                .map(LanguageModel::OpenAi)
                .collect()
        }
    }

    fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// Authenticated as soon as an API key is held in memory.
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    /// Resolves an API key and stores it on the global provider.
    /// The OPENAI_API_KEY environment variable takes precedence over
    /// credentials persisted under the API URL.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    // Credentials are keyed by API URL; the username half of
                    // the pair is ignored here.
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                // Write the key back through the global provider so every
                // handle observes the authenticated state.
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    provider.update_current_as::<_, Self>(|provider| {
                        provider.api_key = Some(api_key);
                    });
                })
            })
        }
    }

    /// Deletes stored credentials (best effort) and clears the in-memory key.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            // Deletion failure is logged, not propagated: the in-memory key is
            // cleared regardless.
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, Self>(|provider| {
                    provider.api_key = None;
                });
            })
        })
    }

    /// Builds the view shown when no API key is configured.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    fn model(&self) -> LanguageModel {
        LanguageModel::OpenAi(self.model.clone())
    }

    /// Token counting runs on the background executor; see
    /// `count_open_ai_tokens` for the model-to-tokenizer mapping.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

    /// Streams completion text chunks from the OpenAI API.
    /// Fails with "missing api key" if `authenticate` has not run.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        // Clone everything needed by the future so it is 'static.
        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // Chunks without a choice or without delta content
                        // (e.g. role-only deltas) are silently dropped.
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
216
217pub fn count_open_ai_tokens(
218 request: LanguageModelRequest,
219 background_executor: &gpui::BackgroundExecutor,
220) -> BoxFuture<'static, Result<usize>> {
221 background_executor
222 .spawn(async move {
223 let messages = request
224 .messages
225 .into_iter()
226 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
227 role: match message.role {
228 Role::User => "user".into(),
229 Role::Assistant => "assistant".into(),
230 Role::System => "system".into(),
231 },
232 content: Some(message.content),
233 name: None,
234 function_call: None,
235 })
236 .collect::<Vec<_>>();
237
238 match request.model {
239 LanguageModel::Anthropic(_)
240 | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
241 | LanguageModel::Cloud(CloudModel::Claude3Opus)
242 | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
243 | LanguageModel::Cloud(CloudModel::Claude3Haiku)
244 | LanguageModel::Cloud(CloudModel::Custom { .. })
245 | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
246 // Tiktoken doesn't yet support these models, so we manually use the
247 // same tokenizer as GPT-4.
248 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
249 }
250 _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
251 }
252 })
253 .boxed()
254}
255
/// View prompting the user to paste an OpenAI API key.
struct AuthenticationPrompt {
    // Single-line editor holding the key being entered.
    api_key: View<Editor>,
    // API URL the credentials will be stored under on confirm.
    api_url: String,
}
260
impl AuthenticationPrompt {
    /// Creates the prompt with an empty single-line key editor; the
    /// placeholder illustrates the expected `sk-...` key format.
    fn new(api_url: String, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(
                    "sk-000000000000000000000000000000000000000000000000",
                    cx,
                );
                editor
            }),
            api_url,
        }
    }

    /// Confirm handler: persists the entered key under `api_url` and then
    /// installs it on the global completion provider. Empty input is a no-op.
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
        cx.spawn(|_, mut cx| async move {
            // Persist first; only a successfully stored key is applied.
            write_credentials.await?;
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                    provider.api_key = Some(api_key);
                });
            })
        })
        .detach_and_log_err(cx);
    }

    /// Renders the key editor styled with the UI font and theme colors.
    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            line_height: relative(1.3),
            ..Default::default()
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}
316
impl Render for AuthenticationPrompt {
    /// Lays out the prompt: instructions, the key editor, and hints about the
    /// OPENAI_API_KEY env var and how to dismiss the panel.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 6] = [
            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
            " - You can create an API key at: platform.openai.com/api-keys",
            " - Make sure your OpenAI account has credits",
            " - Having a subscription for another service like GitHub Copilot won't work.",
            "",
            "Paste your OpenAI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            // menu::Confirm (enter) in the editor triggers save_api_key.
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}