1use anyhow::{anyhow, Result};
2use collections::BTreeMap;
3use editor::{Editor, EditorElement, EditorStyle};
4use futures::{future::BoxFuture, FutureExt, StreamExt};
5use gpui::{
6 AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
7 View, WhiteSpace,
8};
9use http_client::HttpClient;
10use open_ai::{
11 stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
12};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use settings::{Settings, SettingsStore};
16use std::{sync::Arc, time::Duration};
17use strum::IntoEnumIterator;
18use theme::ThemeSettings;
19use ui::{prelude::*, Icon, IconName, Tooltip};
20use util::ResultExt;
21
22use crate::LanguageModelCompletionEvent;
23use crate::{
24 settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
25 LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
26 LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
27};
28
/// Identifier wrapped in `LanguageModelProviderId` by both the provider and its models.
const PROVIDER_ID: &str = "openai";
/// Name wrapped in `LanguageModelProviderName` by both the provider and its models.
const PROVIDER_NAME: &str = "OpenAI";
31
/// Settings for the OpenAI provider, read through `AllLanguageModelSettings`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// API URL passed to `stream_completion`; also the key under which the
    /// API key credentials are written, read, and deleted.
    pub api_url: String,
    /// Timeout forwarded unchanged to `open_ai::stream_completion`.
    pub low_speed_timeout: Option<Duration>,
    /// Extra models exposed as `open_ai::Model::Custom`; an entry whose name
    /// matches a built-in model id replaces that built-in model.
    pub available_models: Vec<AvailableModel>,
    // NOTE(review): not read anywhere in this file; presumably set by the
    // settings loader when an old settings format is detected — confirm.
    pub needs_setting_migration: bool,
}
39
/// A model entry from settings that is turned into an
/// `open_ai::Model::Custom` by the provider. Each field is copied into the
/// field of the same name on that variant.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model name; also the key under which the model is registered.
    pub name: String,
    /// Optional display name for the model.
    pub display_name: Option<String>,
    /// Forwarded as the custom model's `max_tokens`.
    pub max_tokens: usize,
    /// Forwarded as the custom model's `max_output_tokens`.
    pub max_output_tokens: Option<u32>,
    /// Forwarded as the custom model's `max_completion_tokens`.
    pub max_completion_tokens: Option<u32>,
}
48
/// Language-model provider backed by the OpenAI API.
pub struct OpenAiLanguageModelProvider {
    /// HTTP client cloned into every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with the models and the configuration view.
    state: gpui::Model<State>,
}
53
/// Authentication state of the OpenAI provider.
pub struct State {
    /// The API key currently held in memory; `None` means "not authenticated".
    api_key: Option<String>,
    /// `true` when `api_key` was read from the `OPENAI_API_KEY` environment
    /// variable rather than from stored credentials.
    api_key_from_env: bool,
    /// Keeps the `SettingsStore` observation registered in
    /// `OpenAiLanguageModelProvider::new` alive.
    _subscription: Subscription,
}
59
/// Name of the environment variable checked for an API key before falling
/// back to stored credentials.
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
61
62impl State {
63 fn is_authenticated(&self) -> bool {
64 self.api_key.is_some()
65 }
66
67 fn reset_api_key(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
68 let settings = &AllLanguageModelSettings::get_global(cx).openai;
69 let delete_credentials = cx.delete_credentials(&settings.api_url);
70 cx.spawn(|this, mut cx| async move {
71 delete_credentials.await.log_err();
72 this.update(&mut cx, |this, cx| {
73 this.api_key = None;
74 this.api_key_from_env = false;
75 cx.notify();
76 })
77 })
78 }
79
80 fn set_api_key(&mut self, api_key: String, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
81 let settings = &AllLanguageModelSettings::get_global(cx).openai;
82 let write_credentials =
83 cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
84
85 cx.spawn(|this, mut cx| async move {
86 write_credentials.await?;
87 this.update(&mut cx, |this, cx| {
88 this.api_key = Some(api_key);
89 cx.notify();
90 })
91 })
92 }
93
94 fn authenticate(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
95 if self.is_authenticated() {
96 Task::ready(Ok(()))
97 } else {
98 let api_url = AllLanguageModelSettings::get_global(cx)
99 .openai
100 .api_url
101 .clone();
102 cx.spawn(|this, mut cx| async move {
103 let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
104 (api_key, true)
105 } else {
106 let (_, api_key) = cx
107 .update(|cx| cx.read_credentials(&api_url))?
108 .await?
109 .ok_or_else(|| anyhow!("credentials not found"))?;
110 (String::from_utf8(api_key)?, false)
111 };
112 this.update(&mut cx, |this, cx| {
113 this.api_key = Some(api_key);
114 this.api_key_from_env = from_env;
115 cx.notify();
116 })
117 })
118 }
119 }
120}
121
122impl OpenAiLanguageModelProvider {
123 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
124 let state = cx.new_model(|cx| State {
125 api_key: None,
126 api_key_from_env: false,
127 _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
128 cx.notify();
129 }),
130 });
131
132 Self { http_client, state }
133 }
134}
135
136impl LanguageModelProviderState for OpenAiLanguageModelProvider {
137 type ObservableEntity = State;
138
139 fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
140 Some(self.state.clone())
141 }
142}
143
144impl LanguageModelProvider for OpenAiLanguageModelProvider {
145 fn id(&self) -> LanguageModelProviderId {
146 LanguageModelProviderId(PROVIDER_ID.into())
147 }
148
149 fn name(&self) -> LanguageModelProviderName {
150 LanguageModelProviderName(PROVIDER_NAME.into())
151 }
152
153 fn icon(&self) -> IconName {
154 IconName::AiOpenAi
155 }
156
157 fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
158 let mut models = BTreeMap::default();
159
160 // Add base models from open_ai::Model::iter()
161 for model in open_ai::Model::iter() {
162 if !matches!(model, open_ai::Model::Custom { .. }) {
163 models.insert(model.id().to_string(), model);
164 }
165 }
166
167 // Override with available models from settings
168 for model in &AllLanguageModelSettings::get_global(cx)
169 .openai
170 .available_models
171 {
172 models.insert(
173 model.name.clone(),
174 open_ai::Model::Custom {
175 name: model.name.clone(),
176 display_name: model.display_name.clone(),
177 max_tokens: model.max_tokens,
178 max_output_tokens: model.max_output_tokens,
179 max_completion_tokens: model.max_completion_tokens,
180 },
181 );
182 }
183
184 models
185 .into_values()
186 .map(|model| {
187 Arc::new(OpenAiLanguageModel {
188 id: LanguageModelId::from(model.id().to_string()),
189 model,
190 state: self.state.clone(),
191 http_client: self.http_client.clone(),
192 request_limiter: RateLimiter::new(4),
193 }) as Arc<dyn LanguageModel>
194 })
195 .collect()
196 }
197
198 fn is_authenticated(&self, cx: &AppContext) -> bool {
199 self.state.read(cx).is_authenticated()
200 }
201
202 fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
203 self.state.update(cx, |state, cx| state.authenticate(cx))
204 }
205
206 fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
207 cx.new_view(|cx| ConfigurationView::new(self.state.clone(), cx))
208 .into()
209 }
210
211 fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
212 self.state.update(cx, |state, cx| state.reset_api_key(cx))
213 }
214}
215
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    /// Identifier built from `model.id()` when the provider creates the model.
    id: LanguageModelId,
    /// The underlying `open_ai` model description.
    model: open_ai::Model,
    /// Shared provider state; the API key is read from it for each request.
    state: gpui::Model<State>,
    /// HTTP client used to issue completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Limiter that completion requests are routed through (created with a
    /// limit of 4 by the provider).
    request_limiter: RateLimiter,
}
223
224impl OpenAiLanguageModel {
225 fn stream_completion(
226 &self,
227 request: open_ai::Request,
228 cx: &AsyncAppContext,
229 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
230 {
231 let http_client = self.http_client.clone();
232 let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
233 let settings = &AllLanguageModelSettings::get_global(cx).openai;
234 (
235 state.api_key.clone(),
236 settings.api_url.clone(),
237 settings.low_speed_timeout,
238 )
239 }) else {
240 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
241 };
242
243 let future = self.request_limiter.stream(async move {
244 let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenAI API Key"))?;
245 let request = stream_completion(
246 http_client.as_ref(),
247 &api_url,
248 &api_key,
249 request,
250 low_speed_timeout,
251 );
252 let response = request.await?;
253 Ok(response)
254 });
255
256 async move { Ok(future.await?.boxed()) }.boxed()
257 }
258}
259
260impl LanguageModel for OpenAiLanguageModel {
261 fn id(&self) -> LanguageModelId {
262 self.id.clone()
263 }
264
265 fn name(&self) -> LanguageModelName {
266 LanguageModelName::from(self.model.display_name().to_string())
267 }
268
269 fn provider_id(&self) -> LanguageModelProviderId {
270 LanguageModelProviderId(PROVIDER_ID.into())
271 }
272
273 fn provider_name(&self) -> LanguageModelProviderName {
274 LanguageModelProviderName(PROVIDER_NAME.into())
275 }
276
277 fn telemetry_id(&self) -> String {
278 format!("openai/{}", self.model.id())
279 }
280
281 fn max_token_count(&self) -> usize {
282 self.model.max_token_count()
283 }
284
285 fn max_output_tokens(&self) -> Option<u32> {
286 self.model.max_output_tokens()
287 }
288
289 fn count_tokens(
290 &self,
291 request: LanguageModelRequest,
292 cx: &AppContext,
293 ) -> BoxFuture<'static, Result<usize>> {
294 count_open_ai_tokens(request, self.model.clone(), cx)
295 }
296
297 fn stream_completion(
298 &self,
299 request: LanguageModelRequest,
300 cx: &AsyncAppContext,
301 ) -> BoxFuture<
302 'static,
303 Result<futures::stream::BoxStream<'static, Result<LanguageModelCompletionEvent>>>,
304 > {
305 let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
306 let completions = self.stream_completion(request, cx);
307 async move {
308 Ok(open_ai::extract_text_from_events(completions.await?)
309 .map(|result| result.map(LanguageModelCompletionEvent::Text))
310 .boxed())
311 }
312 .boxed()
313 }
314
315 fn use_any_tool(
316 &self,
317 request: LanguageModelRequest,
318 tool_name: String,
319 tool_description: String,
320 schema: serde_json::Value,
321 cx: &AsyncAppContext,
322 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
323 let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
324 request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
325 function: FunctionDefinition {
326 name: tool_name.clone(),
327 description: None,
328 parameters: None,
329 },
330 }));
331 request.tools = vec![ToolDefinition::Function {
332 function: FunctionDefinition {
333 name: tool_name.clone(),
334 description: Some(tool_description),
335 parameters: Some(schema),
336 },
337 }];
338
339 let response = self.stream_completion(request, cx);
340 self.request_limiter
341 .run(async move {
342 let response = response.await?;
343 Ok(
344 open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
345 .await?
346 .boxed(),
347 )
348 })
349 .boxed()
350 }
351}
352
353pub fn count_open_ai_tokens(
354 request: LanguageModelRequest,
355 model: open_ai::Model,
356 cx: &AppContext,
357) -> BoxFuture<'static, Result<usize>> {
358 cx.background_executor()
359 .spawn(async move {
360 let messages = request
361 .messages
362 .into_iter()
363 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
364 role: match message.role {
365 Role::User => "user".into(),
366 Role::Assistant => "assistant".into(),
367 Role::System => "system".into(),
368 },
369 content: Some(message.string_contents()),
370 name: None,
371 function_call: None,
372 })
373 .collect::<Vec<_>>();
374
375 match model {
376 open_ai::Model::Custom { .. }
377 | open_ai::Model::O1Mini
378 | open_ai::Model::O1Preview => {
379 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages)
380 }
381 _ => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
382 }
383 })
384 .boxed()
385}
386
/// View that lets the user enter or reset the OpenAI API key.
struct ConfigurationView {
    /// Single-line editor the API key is typed into.
    api_key_editor: View<Editor>,
    /// Shared authentication state of the provider.
    state: gpui::Model<State>,
    /// `Some` while the initial authentication attempt started in `new` is
    /// still running; set back to `None` when it finishes.
    load_credentials_task: Option<Task<()>>,
}
392
393impl ConfigurationView {
394 fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
395 let api_key_editor = cx.new_view(|cx| {
396 let mut editor = Editor::single_line(cx);
397 editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
398 editor
399 });
400
401 cx.observe(&state, |_, _, cx| {
402 cx.notify();
403 })
404 .detach();
405
406 let load_credentials_task = Some(cx.spawn({
407 let state = state.clone();
408 |this, mut cx| async move {
409 if let Some(task) = state
410 .update(&mut cx, |state, cx| state.authenticate(cx))
411 .log_err()
412 {
413 // We don't log an error, because "not signed in" is also an error.
414 let _ = task.await;
415 }
416
417 this.update(&mut cx, |this, cx| {
418 this.load_credentials_task = None;
419 cx.notify();
420 })
421 .log_err();
422 }
423 }));
424
425 Self {
426 api_key_editor,
427 state,
428 load_credentials_task,
429 }
430 }
431
432 fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
433 let api_key = self.api_key_editor.read(cx).text(cx);
434 if api_key.is_empty() {
435 return;
436 }
437
438 let state = self.state.clone();
439 cx.spawn(|_, mut cx| async move {
440 state
441 .update(&mut cx, |state, cx| state.set_api_key(api_key, cx))?
442 .await
443 })
444 .detach_and_log_err(cx);
445
446 cx.notify();
447 }
448
449 fn reset_api_key(&mut self, cx: &mut ViewContext<Self>) {
450 self.api_key_editor
451 .update(cx, |editor, cx| editor.set_text("", cx));
452
453 let state = self.state.clone();
454 cx.spawn(|_, mut cx| async move {
455 state
456 .update(&mut cx, |state, cx| state.reset_api_key(cx))?
457 .await
458 })
459 .detach_and_log_err(cx);
460
461 cx.notify();
462 }
463
464 fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
465 let settings = ThemeSettings::get_global(cx);
466 let text_style = TextStyle {
467 color: cx.theme().colors().text,
468 font_family: settings.ui_font.family.clone(),
469 font_features: settings.ui_font.features.clone(),
470 font_fallbacks: settings.ui_font.fallbacks.clone(),
471 font_size: rems(0.875).into(),
472 font_weight: settings.ui_font.weight,
473 font_style: FontStyle::Normal,
474 line_height: relative(1.3),
475 background_color: None,
476 underline: None,
477 strikethrough: None,
478 white_space: WhiteSpace::Normal,
479 truncate: None,
480 };
481 EditorElement::new(
482 &self.api_key_editor,
483 EditorStyle {
484 background: cx.theme().colors().editor_background,
485 local_player: cx.theme().players().local(),
486 text: text_style,
487 ..Default::default()
488 },
489 )
490 }
491
492 fn should_render_editor(&self, cx: &mut ViewContext<Self>) -> bool {
493 !self.state.read(cx).is_authenticated()
494 }
495}
496
497impl Render for ConfigurationView {
498 fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
499 const OPENAI_CONSOLE_URL: &str = "https://platform.openai.com/api-keys";
500 const INSTRUCTIONS: [&str; 4] = [
501 "To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:",
502 " - Create one by visiting:",
503 " - Ensure your OpenAI account has credits",
504 " - Paste your API key below and hit enter to start using the assistant",
505 ];
506
507 let env_var_set = self.state.read(cx).api_key_from_env;
508
509 if self.load_credentials_task.is_some() {
510 div().child(Label::new("Loading credentials...")).into_any()
511 } else if self.should_render_editor(cx) {
512 v_flex()
513 .size_full()
514 .on_action(cx.listener(Self::save_api_key))
515 .child(Label::new(INSTRUCTIONS[0]))
516 .child(h_flex().child(Label::new(INSTRUCTIONS[1])).child(
517 Button::new("openai_console", OPENAI_CONSOLE_URL)
518 .style(ButtonStyle::Subtle)
519 .icon(IconName::ExternalLink)
520 .icon_size(IconSize::XSmall)
521 .icon_color(Color::Muted)
522 .on_click(move |_, cx| cx.open_url(OPENAI_CONSOLE_URL))
523 )
524 )
525 .children(
526 (2..INSTRUCTIONS.len()).map(|n|
527 Label::new(INSTRUCTIONS[n])).collect::<Vec<_>>())
528 .child(
529 h_flex()
530 .w_full()
531 .my_2()
532 .px_2()
533 .py_1()
534 .bg(cx.theme().colors().editor_background)
535 .rounded_md()
536 .child(self.render_api_key_editor(cx)),
537 )
538 .child(
539 Label::new(
540 format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
541 )
542 .size(LabelSize::Small),
543 )
544 .child(
545 Label::new(
546 "Note that having a subscription for another service like GitHub Copilot won't work.".to_string(),
547 )
548 .size(LabelSize::Small),
549 )
550 .into_any()
551 } else {
552 h_flex()
553 .size_full()
554 .justify_between()
555 .child(
556 h_flex()
557 .gap_1()
558 .child(Icon::new(IconName::Check).color(Color::Success))
559 .child(Label::new(if env_var_set {
560 format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
561 } else {
562 "API key configured.".to_string()
563 })),
564 )
565 .child(
566 Button::new("reset-key", "Reset key")
567 .icon(Some(IconName::Trash))
568 .icon_size(IconSize::Small)
569 .icon_position(IconPosition::Start)
570 .disabled(env_var_set)
571 .when(env_var_set, |this| {
572 this.tooltip(|cx| Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable."), cx))
573 })
574 .on_click(cx.listener(|this, _, cx| this.reset_api_key(cx))),
575 )
576 .into_any()
577 }
578 }
579}