1use crate::assistant_settings::CloudModel;
2use crate::{
3 assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
4};
5use anyhow::{anyhow, Result};
6use editor::{Editor, EditorElement, EditorStyle};
7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
8use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
9use http::HttpClient;
10use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
11use settings::Settings;
12use std::time::Duration;
13use std::{env, sync::Arc};
14use strum::IntoEnumIterator;
15use theme::ThemeSettings;
16use ui::prelude::*;
17use util::ResultExt;
18
19pub struct OpenAiCompletionProvider {
20 api_key: Option<String>,
21 api_url: String,
22 model: OpenAiModel,
23 http_client: Arc<dyn HttpClient>,
24 low_speed_timeout: Option<Duration>,
25 settings_version: usize,
26}
27
28impl OpenAiCompletionProvider {
29 pub fn new(
30 model: OpenAiModel,
31 api_url: String,
32 http_client: Arc<dyn HttpClient>,
33 low_speed_timeout: Option<Duration>,
34 settings_version: usize,
35 ) -> Self {
36 Self {
37 api_key: None,
38 api_url,
39 model,
40 http_client,
41 low_speed_timeout,
42 settings_version,
43 }
44 }
45
46 pub fn update(
47 &mut self,
48 model: OpenAiModel,
49 api_url: String,
50 low_speed_timeout: Option<Duration>,
51 settings_version: usize,
52 ) {
53 self.model = model;
54 self.api_url = api_url;
55 self.low_speed_timeout = low_speed_timeout;
56 self.settings_version = settings_version;
57 }
58
59 pub fn available_models(&self) -> impl Iterator<Item = OpenAiModel> {
60 OpenAiModel::iter()
61 }
62
63 pub fn settings_version(&self) -> usize {
64 self.settings_version
65 }
66
67 pub fn is_authenticated(&self) -> bool {
68 self.api_key.is_some()
69 }
70
71 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
72 if self.is_authenticated() {
73 Task::ready(Ok(()))
74 } else {
75 let api_url = self.api_url.clone();
76 cx.spawn(|mut cx| async move {
77 let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
78 api_key
79 } else {
80 let (_, api_key) = cx
81 .update(|cx| cx.read_credentials(&api_url))?
82 .await?
83 .ok_or_else(|| anyhow!("credentials not found"))?;
84 String::from_utf8(api_key)?
85 };
86 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
87 if let CompletionProvider::OpenAi(provider) = provider {
88 provider.api_key = Some(api_key);
89 }
90 })
91 })
92 }
93 }
94
95 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
96 let delete_credentials = cx.delete_credentials(&self.api_url);
97 cx.spawn(|mut cx| async move {
98 delete_credentials.await.log_err();
99 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
100 if let CompletionProvider::OpenAi(provider) = provider {
101 provider.api_key = None;
102 }
103 })
104 })
105 }
106
107 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
108 cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
109 .into()
110 }
111
112 pub fn model(&self) -> OpenAiModel {
113 self.model.clone()
114 }
115
116 pub fn count_tokens(
117 &self,
118 request: LanguageModelRequest,
119 cx: &AppContext,
120 ) -> BoxFuture<'static, Result<usize>> {
121 count_open_ai_tokens(request, cx.background_executor())
122 }
123
124 pub fn complete(
125 &self,
126 request: LanguageModelRequest,
127 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
128 let request = self.to_open_ai_request(request);
129
130 let http_client = self.http_client.clone();
131 let api_key = self.api_key.clone();
132 let api_url = self.api_url.clone();
133 let low_speed_timeout = self.low_speed_timeout;
134 async move {
135 let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
136 let request = stream_completion(
137 http_client.as_ref(),
138 &api_url,
139 &api_key,
140 request,
141 low_speed_timeout,
142 );
143 let response = request.await?;
144 let stream = response
145 .filter_map(|response| async move {
146 match response {
147 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
148 Err(error) => Some(Err(error)),
149 }
150 })
151 .boxed();
152 Ok(stream)
153 }
154 .boxed()
155 }
156
157 fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
158 let model = match request.model {
159 LanguageModel::OpenAi(model) => model,
160 _ => self.model(),
161 };
162
163 Request {
164 model,
165 messages: request
166 .messages
167 .into_iter()
168 .map(|msg| match msg.role {
169 Role::User => RequestMessage::User {
170 content: msg.content,
171 },
172 Role::Assistant => RequestMessage::Assistant {
173 content: Some(msg.content),
174 tool_calls: Vec::new(),
175 },
176 Role::System => RequestMessage::System {
177 content: msg.content,
178 },
179 })
180 .collect(),
181 stream: true,
182 stop: request.stop,
183 temperature: request.temperature,
184 tools: Vec::new(),
185 tool_choice: None,
186 }
187 }
188}
189
190pub fn count_open_ai_tokens(
191 request: LanguageModelRequest,
192 background_executor: &gpui::BackgroundExecutor,
193) -> BoxFuture<'static, Result<usize>> {
194 background_executor
195 .spawn(async move {
196 let messages = request
197 .messages
198 .into_iter()
199 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
200 role: match message.role {
201 Role::User => "user".into(),
202 Role::Assistant => "assistant".into(),
203 Role::System => "system".into(),
204 },
205 content: Some(message.content),
206 name: None,
207 function_call: None,
208 })
209 .collect::<Vec<_>>();
210
211 match request.model {
212 LanguageModel::Anthropic(_)
213 | LanguageModel::Cloud(CloudModel::Claude3Opus)
214 | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
215 | LanguageModel::Cloud(CloudModel::Claude3Haiku) => {
216 // Tiktoken doesn't yet support these models, so we manually use the
217 // same tokenizer as GPT-4.
218 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
219 }
220 _ => tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages),
221 }
222 })
223 .boxed()
224}
225
226impl From<Role> for open_ai::Role {
227 fn from(val: Role) -> Self {
228 match val {
229 Role::User => OpenAiRole::User,
230 Role::Assistant => OpenAiRole::Assistant,
231 Role::System => OpenAiRole::System,
232 }
233 }
234}
235
236struct AuthenticationPrompt {
237 api_key: View<Editor>,
238 api_url: String,
239}
240
241impl AuthenticationPrompt {
242 fn new(api_url: String, cx: &mut WindowContext) -> Self {
243 Self {
244 api_key: cx.new_view(|cx| {
245 let mut editor = Editor::single_line(cx);
246 editor.set_placeholder_text(
247 "sk-000000000000000000000000000000000000000000000000",
248 cx,
249 );
250 editor
251 }),
252 api_url,
253 }
254 }
255
256 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
257 let api_key = self.api_key.read(cx).text(cx);
258 if api_key.is_empty() {
259 return;
260 }
261
262 let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
263 cx.spawn(|_, mut cx| async move {
264 write_credentials.await?;
265 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
266 if let CompletionProvider::OpenAi(provider) = provider {
267 provider.api_key = Some(api_key);
268 }
269 })
270 })
271 .detach_and_log_err(cx);
272 }
273
274 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
275 let settings = ThemeSettings::get_global(cx);
276 let text_style = TextStyle {
277 color: cx.theme().colors().text,
278 font_family: settings.ui_font.family.clone(),
279 font_features: settings.ui_font.features.clone(),
280 font_size: rems(0.875).into(),
281 font_weight: settings.ui_font.weight,
282 font_style: FontStyle::Normal,
283 line_height: relative(1.3),
284 background_color: None,
285 underline: None,
286 strikethrough: None,
287 white_space: WhiteSpace::Normal,
288 };
289 EditorElement::new(
290 &self.api_key,
291 EditorStyle {
292 background: cx.theme().colors().editor_background,
293 local_player: cx.theme().players().local(),
294 text: text_style,
295 ..Default::default()
296 },
297 )
298 }
299}
300
301impl Render for AuthenticationPrompt {
302 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
303 const INSTRUCTIONS: [&str; 6] = [
304 "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
305 " - You can create an API key at: platform.openai.com/api-keys",
306 " - Make sure your OpenAI account has credits",
307 " - Having a subscription for another service like GitHub Copilot won't work.",
308 "",
309 "Paste your OpenAI API key below and hit enter to use the assistant:",
310 ];
311
312 v_flex()
313 .p_4()
314 .size_full()
315 .on_action(cx.listener(Self::save_api_key))
316 .children(
317 INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
318 )
319 .child(
320 h_flex()
321 .w_full()
322 .my_2()
323 .px_2()
324 .py_1()
325 .bg(cx.theme().colors().editor_background)
326 .rounded_md()
327 .child(self.render_api_key_editor(cx)),
328 )
329 .child(
330 Label::new(
331 "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
332 )
333 .size(LabelSize::Small),
334 )
335 .child(
336 h_flex()
337 .gap_2()
338 .child(Label::new("Click on").size(LabelSize::Small))
339 .child(Icon::new(IconName::ZedAssistant).size(IconSize::XSmall))
340 .child(
341 Label::new("in the status bar to close this panel.").size(LabelSize::Small),
342 ),
343 )
344 .into_any()
345 }
346}