1use crate::assistant_settings::CloudModel;
2use crate::assistant_settings::{AssistantProvider, AssistantSettings};
3use crate::{
4 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
5};
6use anyhow::{anyhow, Result};
7use editor::{Editor, EditorElement, EditorStyle};
8use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
9use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
10use http::HttpClient;
11use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
12use settings::Settings;
13use std::time::Duration;
14use std::{env, sync::Arc};
15use strum::IntoEnumIterator;
16use theme::ThemeSettings;
17use ui::prelude::*;
18use util::ResultExt;
19
20pub struct OpenAiCompletionProvider {
21 api_key: Option<String>,
22 api_url: String,
23 model: OpenAiModel,
24 http_client: Arc<dyn HttpClient>,
25 low_speed_timeout: Option<Duration>,
26 settings_version: usize,
27}
28
29impl OpenAiCompletionProvider {
30 pub fn new(
31 model: OpenAiModel,
32 api_url: String,
33 http_client: Arc<dyn HttpClient>,
34 low_speed_timeout: Option<Duration>,
35 settings_version: usize,
36 ) -> Self {
37 Self {
38 api_key: None,
39 api_url,
40 model,
41 http_client,
42 low_speed_timeout,
43 settings_version,
44 }
45 }
46
47 pub fn update(
48 &mut self,
49 model: OpenAiModel,
50 api_url: String,
51 low_speed_timeout: Option<Duration>,
52 settings_version: usize,
53 ) {
54 self.model = model;
55 self.api_url = api_url;
56 self.low_speed_timeout = low_speed_timeout;
57 self.settings_version = settings_version;
58 }
59
60 pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
61 if let AssistantProvider::OpenAi {
62 available_models, ..
63 } = &AssistantSettings::get_global(cx).provider
64 {
65 if !available_models.is_empty() {
66 // available_models is set, just return it
67 return available_models.clone().into_iter();
68 }
69 }
70 let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
71 // available_models is not set but the default model is set to custom, only show custom
72 vec![self.model.clone()]
73 } else {
74 // default case, use all models except custom
75 OpenAiModel::iter()
76 .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
77 .collect()
78 };
79 available_models.into_iter()
80 }
81
82 pub fn settings_version(&self) -> usize {
83 self.settings_version
84 }
85
86 pub fn is_authenticated(&self) -> bool {
87 self.api_key.is_some()
88 }
89
90 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
91 if self.is_authenticated() {
92 Task::ready(Ok(()))
93 } else {
94 let api_url = self.api_url.clone();
95 cx.spawn(|mut cx| async move {
96 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
97 api_key
98 } else {
99 let (_, api_key) = cx
100 .update(|cx| cx.read_credentials(&api_url))?
101 .await?
102 .ok_or_else(|| anyhow!("credentials not found"))?;
103 String::from_utf8(api_key)?
104 };
105 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
106 if let CompletionProvider::OpenAi(provider) = provider {
107 provider.api_key = Some(api_key);
108 }
109 })
110 })
111 }
112 }
113
114 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
115 let delete_credentials = cx.delete_credentials(&self.api_url);
116 cx.spawn(|mut cx| async move {
117 delete_credentials.await.log_err();
118 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
119 if let CompletionProvider::OpenAi(provider) = provider {
120 provider.api_key = None;
121 }
122 })
123 })
124 }
125
126 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
127 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
128 .into()
129 }
130
131 pub fn model(&self) -> OpenAiModel {
132 self.model.clone()
133 }
134
135 pub fn count_tokens(
136 &self,
137 request: LanguageModelRequest,
138 cx: &AppContext,
139 ) -> BoxFuture<'static, Result<usize>> {
140 count_open_ai_tokens(request, cx.background_executor())
141 }
142
143 pub fn complete(
144 &self,
145 request: LanguageModelRequest,
146 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
147 let request = self.to_open_ai_request(request);
148
149 let http_client = self.http_client.clone();
150 let api_key = self.api_key.clone();
151 let api_url = self.api_url.clone();
152 let low_speed_timeout = self.low_speed_timeout;
153 async move {
154 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
155 let request = stream_completion(
156 http_client.as_ref(),
157 &api_url,
158 &api_key,
159 request,
160 low_speed_timeout,
161 );
162 let response = request.await?;
163 let stream = response
164 .filter_map(|response| async move {
165 match response {
166 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
167 Err(error) => Some(Err(error)),
168 }
169 })
170 .boxed();
171 Ok(stream)
172 }
173 .boxed()
174 }
175
176 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
177 let model = match request.model {
178 LanguageModel::OpenAi(model) => model,
179 _ => self.model(),
180 };
181
182 Request {
183 model,
184 messages: request
185 .messages
186 .into_iter()
187 .map(|msg| match msg.role {
188 Role::User => RequestMessage::User {
189 content: msg.content,
190 },
191 Role::Assistant => RequestMessage::Assistant {
192 content: Some(msg.content),
193 tool_calls: Vec::new(),
194 },
195 Role::System => RequestMessage::System {
196 content: msg.content,
197 },
198 })
199 .collect(),
200 stream: true,
201 stop: request.stop,
202 temperature: request.temperature,
203 tools: Vec::new(),
204 tool_choice: None,
205 }
206 }
207}
208
209pub fn count_open_ai_tokens(
210 request: LanguageModelRequest,
211 background_executor: &gpui::BackgroundExecutor,
212) -> BoxFuture<'static, Result<usize>> {
213 background_executor
214 .spawn(async move {
215 let messages = request
216 .messages
217 .into_iter()
218 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
219 role: match message.role {
220 Role::User => "user".into(),
221 Role::Assistant => "assistant".into(),
222 Role::System => "system".into(),
223 },
224 content: Some(message.content),
225 name: None,
226 function_call: None,
227 })
228 .collect::<Vec<_>>();
229
230 match request.model {
231 LanguageModel::Anthropic(_)
232 | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
233 | LanguageModel::Cloud(CloudModel::Claude3Opus)
234 | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
235 | LanguageModel::Cloud(CloudModel::Claude3Haiku)
236 | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
237 // Tiktoken doesn't yet support these models, so we manually use the
238 // same tokenizer as GPT-4.
239 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
240 }
241 _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
242 }
243 })
244 .boxed()
245}
246
247impl From<Role> for open_ai::Role {
248 fn from(val: Role) -> Self {
249 match val {
250 Role::User => OpenAiRole::User,
251 Role::Assistant => OpenAiRole::Assistant,
252 Role::System => OpenAiRole::System,
253 }
254 }
255}
256
257struct AuthenticationPrompt {
258 api_key: View<Editor>,
259 api_url: String,
260}
261
262impl AuthenticationPrompt {
263 fn new(api_url: String, cx: &mut WindowContext) -> Self {
264 Self {
265 api_key: cx.new_view(|cx| {
266 let mut editor = Editor::single_line(cx);
267 editor.set_placeholder_text(
268 "sk-000000000000000000000000000000000000000000000000",
269 cx,
270 );
271 editor
272 }),
273 api_url,
274 }
275 }
276
277 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
278 let api_key = self.api_key.read(cx).text(cx);
279 if api_key.is_empty() {
280 return;
281 }
282
283 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
284 cx.spawn(|_, mut cx| async move {
285 write_credentials.await?;
286 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
287 if let CompletionProvider::OpenAi(provider) = provider {
288 provider.api_key = Some(api_key);
289 }
290 })
291 })
292 .detach_and_log_err(cx);
293 }
294
295 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
296 let settings = ThemeSettings::get_global(cx);
297 let text_style = TextStyle {
298 color: cx.theme().colors().text,
299 font_family: settings.ui_font.family.clone(),
300 font_features: settings.ui_font.features.clone(),
301 font_size: rems(0.875).into(),
302 font_weight: settings.ui_font.weight,
303 font_style: FontStyle::Normal,
304 line_height: relative(1.3),
305 background_color: None,
306 underline: None,
307 strikethrough: None,
308 white_space: WhiteSpace::Normal,
309 };
310 EditorElement::new(
311 &self.api_key,
312 EditorStyle {
313 background: cx.theme().colors().editor_background,
314 local_player: cx.theme().players().local(),
315 text: text_style,
316 ..Default::default()
317 },
318 )
319 }
320}
321
322impl Render for AuthenticationPrompt {
323 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
324 const INSTRUCTIONS: [&str; 6] = [
325 "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
326 " - You can create an API key at: platform.openai.com/api-keys",
327 " - Make sure your OpenAI account has credits",
328 " - Having a subscription for another service like GitHub Copilot won't work.",
329 "",
330 "Paste your OpenAI API key below and hit enter to use the assistant:",
331 ];
332
333 v_flex()
334 .p_4()
335 .size_full()
336 .on_action(cx.listener(Self::save_api_key))
337 .children(
338 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
339 )
340 .child(
341 h_flex()
342 .w_full()
343 .my_2()
344 .px_2()
345 .py_1()
346 .bg(cx.theme().colors().editor_background)
347 .rounded_md()
348 .child(self.render_api_key_editor(cx)),
349 )
350 .child(
351 Label::new(
352 "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
353 )
354 .size(LabelSize::Small),
355 )
356 .child(
357 h_flex()
358 .gap_2()
359 .child(Label::new("Click on").size(LabelSize::Small))
360 .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
361 .child(
362 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
363 ),
364 )
365 .into_any()
366 }
367}