use crate::assistant_settings::CloudModel;
use crate::{
    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
use util::ResultExt;

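/// Completion provider backed by the OpenAI chat completions API, or any
/// OpenAI-compatible endpoint reachable at the configured `api_url`.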
pub struct OpenAiCompletionProvider {
    api_key: Option<String>,
    api_url: String,
    model: OpenAiModel,
    http_client: Arc<dyn HttpClient>,
    low_speed_timeout: Option<Duration>,
    settings_version: usize,
}

impl OpenAiCompletionProvider {
    pub fn new(
        model: OpenAiModel,
        api_url: String,
        http_client: Arc<dyn HttpClient>,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) -> Self {
        Self {
            api_key: None,
            api_url,
            model,
            http_client,
            low_speed_timeout,
            settings_version,
        }
    }

    pub fn update(
        &mut self,
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout: Option<Duration>,
        settings_version: usize,
    ) {
        self.model = model;
        self.api_url = api_url;
        self.low_speed_timeout = low_speed_timeout;
        self.settings_version = settings_version;
    }

    pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
        OpenAiModel::iter()
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    pub fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

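    /// Resolves an API key, preferring the `OPENAI_API_KEY` environment variable
    /// and falling back to credentials previously stored for `api_url`.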
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            let api_url = self.api_url.clone();
            cx.spawn(|mut cx| async move {
                let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
                    api_key
                } else {
                    let (_, api_key) = cx
                        .update(|cx| cx.read_credentials(&api_url))?
                        .await?
                        .ok_or_else(|| anyhow!("credentials not found"))?;
                    String::from_utf8(api_key)?
                };
                cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    if let CompletionProvider::OpenAi(provider) = provider {
                        provider.api_key = Some(api_key);
                    }
                })
            })
        }
    }

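    /// Deletes any credentials stored for `api_url` and clears the in-memory key.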
    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
        let delete_credentials = cx.delete_credentials(&self.api_url);
        cx.spawn(|mut cx| async move {
            delete_credentials.await.log_err();
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = None;
                }
            })
        })
    }

    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
            .into()
    }

    pub fn model(&self) -> OpenAiModel {
        self.model.clone()
    }

    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        count_open_ai_tokens(request, cx.background_executor())
    }

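    /// Streams a completion, yielding each content delta from the response as a
    /// separate `String` chunk.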
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_open_ai_request(request);

        let http_client = self.http_client.clone();
        let api_key = self.api_key.clone();
        let api_url = self.api_url.clone();
        let low_speed_timeout = self.low_speed_timeout;
        async move {
            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
            let request = stream_completion(
                http_client.as_ref(),
                &api_url,
                &api_key,
                request,
                low_speed_timeout,
            );
            let response = request.await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        }
        .boxed()
    }

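    /// Converts a generic `LanguageModelRequest` into an OpenAI request, falling
    /// back to this provider's configured model when the request targets a
    /// non-OpenAI model.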
    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
        let model = match request.model {
            LanguageModel::OpenAi(model) => model,
            _ => self.model(),
        };

        Request {
            model,
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => RequestMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => RequestMessage::Assistant {
                        content: Some(msg.content),
                        tool_calls: Vec::new(),
                    },
                    Role::System => RequestMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            stream: true,
            stop: request.stop,
            temperature: request.temperature,
            tools: Vec::new(),
            tool_choice: None,
        }
    }
}

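/// Counts the request's tokens with tiktoken-rs, running on the background executor.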
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    background_executor: &gpui::BackgroundExecutor,
) -> BoxFuture<'static, Result<usize>> {
    background_executor
        .spawn(async move {
            let messages = request
                .messages
                .into_iter()
                .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(message.content),
                    name: None,
                    function_call: None,
                })
                .collect::<Vec<_>>();

            match request.model {
                LanguageModel::Anthropic(_)
                | LanguageModel::Cloud(CloudModel::Claude3_5Sonnet)
                | LanguageModel::Cloud(CloudModel::Claude3Opus)
                | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
                | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
                    // Tiktoken doesn't yet support these models, so we manually use the
                    // same tokenizer as GPT-4.
                    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
                }
                _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
            }
        })
        .boxed()
}

impl From<Role> for open_ai::Role {
    fn from(val: Role) -> Self {
        match val {
            Role::User => OpenAiRole::User,
            Role::Assistant => OpenAiRole::Assistant,
            Role::System => OpenAiRole::System,
        }
    }
}

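/// View that prompts the user to paste an OpenAI API key and persists it via
/// `write_credentials`, keyed by the configured API URL.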
struct AuthenticationPrompt {
    api_key: View<Editor>,
    api_url: String,
}

impl AuthenticationPrompt {
    fn new(api_url: String, cx: &mut WindowContext) -> Self {
        Self {
            api_key: cx.new_view(|cx| {
                let mut editor = Editor::single_line(cx);
                editor.set_placeholder_text(
                    "sk-000000000000000000000000000000000000000000000000",
                    cx,
                );
                editor
            }),
            api_url,
        }
    }

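    /// Stores the entered key for `api_url` and updates the global
    /// `CompletionProvider` so the new key takes effect immediately.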
    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
        let api_key = self.api_key.read(cx).text(cx);
        if api_key.is_empty() {
            return;
        }

        let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
        cx.spawn(|_, mut cx| async move {
            write_credentials.await?;
            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                if let CompletionProvider::OpenAi(provider) = provider {
                    provider.api_key = Some(api_key);
                }
            })
        })
        .detach_and_log_err(cx);
    }

    fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let settings = ThemeSettings::get_global(cx);
        let text_style = TextStyle {
            color: cx.theme().colors().text,
            font_family: settings.ui_font.family.clone(),
            font_features: settings.ui_font.features.clone(),
            font_size: rems(0.875).into(),
            font_weight: settings.ui_font.weight,
            font_style: FontStyle::Normal,
            line_height: relative(1.3),
            background_color: None,
            underline: None,
            strikethrough: None,
            white_space: WhiteSpace::Normal,
        };
        EditorElement::new(
            &self.api_key,
            EditorStyle {
                background: cx.theme().colors().editor_background,
                local_player: cx.theme().players().local(),
                text: text_style,
                ..Default::default()
            },
        )
    }
}

impl Render for AuthenticationPrompt {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        const INSTRUCTIONS: [&str; 6] = [
            "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
            " - You can create an API key at: platform.openai.com/api-keys",
            " - Make sure your OpenAI account has credits",
            " - Having a subscription for another service like GitHub Copilot won't work.",
            "",
            "Paste your OpenAI API key below and hit enter to use the assistant:",
        ];

        v_flex()
            .p_4()
            .size_full()
            .on_action(cx.listener(Self::save_api_key))
            .children(
                INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
            )
            .child(
                h_flex()
                    .w_full()
                    .my_2()
                    .px_2()
                    .py_1()
                    .bg(cx.theme().colors().editor_background)
                    .rounded_md()
                    .child(self.render_api_key_editor(cx)),
            )
            .child(
                Label::new(
                    "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
                )
                .size(LabelSize::Small),
            )
            .child(
                h_flex()
                    .gap_2()
                    .child(Label::new("Click on").size(LabelSize::Small))
                    .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
                    .child(
                        Label::new("in the status bar to close this panel.").size(LabelSize::Small),
                    ),
            )
            .into_any()
    }
}