1use crate::assistant_settings::CloudModel;
2use crate::assistant_settings::{AssistantProvider, AssistantSettings};
3use crate::LanguageModelCompletionProvider;
4use crate::{
5 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
6};
7use anyhow::{anyhow, Result};
8use editor::{Editor, EditorElement, EditorStyle};
9use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
10use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
11use http::HttpClient;
12use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
13use settings::Settings;
14use std::time::Duration;
15use std::{env, sync::Arc};
16use strum::IntoEnumIterator;
17use theme::ThemeSettings;
18use ui::prelude::*;
19use util::ResultExt;
20
/// Completion provider backed by the OpenAI chat-completions API.
pub struct OpenAiCompletionProvider {
    /// API key; `None` until `authenticate` succeeds (env var or stored credentials).
    api_key: Option<String>,
    /// Base URL for API requests; also used as the key for credential storage.
    api_url: String,
    /// Default model, used when a request doesn't name an OpenAI model itself.
    model: OpenAiModel,
    /// Client used to issue the streaming completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Optional timeout applied when the streaming response is too slow.
    low_speed_timeout: Option<Duration>,
    /// Version of the settings this provider was last configured from.
    settings_version: usize,
}
29
30impl OpenAiCompletionProvider {
31 pub fn new(
32 model: OpenAiModel,
33 api_url: String,
34 http_client: Arc<dyn HttpClient>,
35 low_speed_timeout: Option<Duration>,
36 settings_version: usize,
37 ) -> Self {
38 Self {
39 api_key: None,
40 api_url,
41 model,
42 http_client,
43 low_speed_timeout,
44 settings_version,
45 }
46 }
47
48 pub fn update(
49 &mut self,
50 model: OpenAiModel,
51 api_url: String,
52 low_speed_timeout: Option<Duration>,
53 settings_version: usize,
54 ) {
55 self.model = model;
56 self.api_url = api_url;
57 self.low_speed_timeout = low_speed_timeout;
58 self.settings_version = settings_version;
59 }
60
61 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
62 let model = match request.model {
63 LanguageModel::OpenAi(model) => model,
64 _ => self.model.clone(),
65 };
66
67 Request {
68 model,
69 messages: request
70 .messages
71 .into_iter()
72 .map(|msg| match msg.role {
73 Role::User => RequestMessage::User {
74 content: msg.content,
75 },
76 Role::Assistant => RequestMessage::Assistant {
77 content: Some(msg.content),
78 tool_calls: Vec::new(),
79 },
80 Role::System => RequestMessage::System {
81 content: msg.content,
82 },
83 })
84 .collect(),
85 stream: true,
86 stop: request.stop,
87 temperature: request.temperature,
88 tools: Vec::new(),
89 tool_choice: None,
90 }
91 }
92}
93
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
    /// Returns the models the user may choose from. Models listed explicitly
    /// in the assistant settings win; otherwise fall back to the built-in
    /// model list, or to just the current model when it's a custom one.
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
        if let AssistantProvider::OpenAi {
            available_models, ..
        } = &AssistantSettings::get_global(cx).provider
        {
            if !available_models.is_empty() {
                // Settings explicitly enumerate models; use them verbatim.
                return available_models
                    .iter()
                    .cloned()
                    .map(LanguageModel::OpenAi)
                    .collect();
            }
        }
        let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
            // A custom model can't be enumerated; offer only the current one.
            vec![self.model.clone()]
        } else {
            // All known models, excluding the `Custom` placeholder variant.
            OpenAiModel::iter()
                .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                .collect()
        };
        available_models
            .into_iter()
            .map(LanguageModel::OpenAi)
            .collect()
    }

    /// Version of the settings this provider was last updated from.
    fn settings_version(&self) -> usize {
        self.settings_version
    }

    /// Authenticated once an API key has been resolved.
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    /// Resolves an API key from the `OPENAI_API_KEY` environment variable or,
    /// failing that, from the credentials stored under `api_url`, then stores
    /// it on the current provider via the global `CompletionProvider`.
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    // Credentials are stored as (username, password) bytes;
                    // only the password (the key) is used.
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    provider.update_current_as::<_, Self>(|provider| {
                        provider.api_key = Some(api_key);
                    });
                })
            })
        }
    }

    /// Deletes the stored credentials and clears the in-memory API key.
    /// Deletion failures are logged rather than propagated.
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                provider.update_current_as::<_, Self>(|provider| {
                    provider.api_key = None;
                });
            })
        })
    }

    /// View shown when no API key is configured, prompting the user for one.
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    /// The provider's current default model.
    fn model(&self) -> LanguageModel {
        LanguageModel::OpenAi(self.model.clone())
    }

    /// Counts request tokens on the background executor.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

    /// Streams completion text deltas for `request`. Fails early if no API
    /// key has been set.
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        // Clone everything the future needs so it can be 'static.
        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        // Take the last choice's content delta; chunks with no
                        // choices or no content (the `?`s) are filtered out.
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
219
220pub fn count_open_ai_tokens(
221 request: LanguageModelRequest,
222 background_executor: &gpui::BackgroundExecutor,
223) -> BoxFuture<'static, Result<usize>> {
224 background_executor
225 .spawn(async move {
226 let messages = request
227 .messages
228 .into_iter()
229 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
230 role: match message.role {
231 Role::User => "user".into(),
232 Role::Assistant => "assistant".into(),
233 Role::System => "system".into(),
234 },
235 content: Some(message.content),
236 name: None,
237 function_call: None,
238 })
239 .collect::<Vec<_>>();
240
241 match request.model {
242 LanguageModel::Anthropic(_)
243 | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
244 | LanguageModel::Cloud(CloudModel::Claude3Opus)
245 | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
246 | LanguageModel::Cloud(CloudModel::Claude3Haiku)
247 | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
248 // Tiktoken doesn't yet support these models, so we manually use the
249 // same tokenizer as GPT-4.
250 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
251 }
252 _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
253 }
254 })
255 .boxed()
256}
257
258impl From<Role> for open_ai::Role {
259 fn from(val: Role) -> Self {
260 match val {
261 Role::User => OpenAiRole::User,
262 Role::Assistant => OpenAiRole::Assistant,
263 Role::System => OpenAiRole::System,
264 }
265 }
266}
267
/// View that prompts the user to enter an OpenAI API key.
struct AuthenticationPrompt {
    /// Single-line editor the user types the API key into.
    api_key: View<Editor>,
    /// API URL, used as the key under which the credentials are stored.
    api_url: String,
}
272
273impl AuthenticationPrompt {
274 fn new(api_url: String, cx: &mut WindowContext) -> Self {
275 Self {
276 api_key: cx.new_view(|cx| {
277 let mut editor = Editor::single_line(cx);
278 editor.set_placeholder_text(
279 "sk-000000000000000000000000000000000000000000000000",
280 cx,
281 );
282 editor
283 }),
284 api_url,
285 }
286 }
287
288 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
289 let api_key = self.api_key.read(cx).text(cx);
290 if api_key.is_empty() {
291 return;
292 }
293
294 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
295 cx.spawn(|_, mut cx| async move {
296 write_credentials.await?;
297 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
298 provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
299 provider.api_key = Some(api_key);
300 });
301 })
302 })
303 .detach_and_log_err(cx);
304 }
305
306 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
307 let settings = ThemeSettings::get_global(cx);
308 let text_style = TextStyle {
309 color: cx.theme().colors().text,
310 font_family: settings.ui_font.family.clone(),
311 font_features: settings.ui_font.features.clone(),
312 font_size: rems(0.875).into(),
313 font_weight: settings.ui_font.weight,
314 font_style: FontStyle::Normal,
315 line_height: relative(1.3),
316 background_color: None,
317 underline: None,
318 strikethrough: None,
319 white_space: WhiteSpace::Normal,
320 };
321 EditorElement::new(
322 &self.api_key,
323 EditorStyle {
324 background: cx.theme().colors().editor_background,
325 local_player: cx.theme().players().local(),
326 text: text_style,
327 ..Default::default()
328 },
329 )
330 }
331}
332
impl Render for AuthenticationPrompt {
    /// Renders the instructions, the key editor, and hints for alternative
    /// setup (environment variable) and dismissing the panel.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 6] = [
            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
            " - You can create an API key at: platform.openai.com/api-keys",
            " - Make sure your OpenAI account has credits",
            " - Having a subscription for another service like GitHub Copilot won't work.",
            "",
            "Paste your OpenAI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            // `menu::Confirm` (enter) in the editor triggers saving the key.
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}