1use anyhow::{Context as _, Result, anyhow};
2use credentials_provider::CredentialsProvider;
3
4use convert_case::{Case, Casing};
5use futures::{FutureExt, StreamExt, future::BoxFuture};
6use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
7use http_client::HttpClient;
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
10 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
11 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
12 LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
13};
14use menu;
15use open_ai::{ResponseStreamEvent, stream_completion};
16use settings::{Settings, SettingsStore};
17use std::sync::Arc;
18
19use ui::{ElevationIndex, Tooltip, prelude::*};
20use ui_input::SingleLineInput;
21use util::ResultExt;
22
23use crate::AllLanguageModelSettings;
24use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
25pub use settings::OpenAiCompatibleAvailableModel as AvailableModel;
26pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
27
28#[derive(Default, Clone, Debug, PartialEq)]
29pub struct OpenAiCompatibleSettings {
30 pub api_url: String,
31 pub available_models: Vec<AvailableModel>,
32}
33
34pub struct OpenAiCompatibleLanguageModelProvider {
35 id: LanguageModelProviderId,
36 name: LanguageModelProviderName,
37 http_client: Arc<dyn HttpClient>,
38 state: gpui::Entity<State>,
39}
40
41pub struct State {
42 id: Arc<str>,
43 env_var_name: Arc<str>,
44 api_key: Option<String>,
45 api_key_from_env: bool,
46 settings: OpenAiCompatibleSettings,
47 _subscription: Subscription,
48}
49
50impl State {
51 fn is_authenticated(&self) -> bool {
52 self.api_key.is_some()
53 }
54
55 fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
56 let credentials_provider = <dyn CredentialsProvider>::global(cx);
57 let api_url = self.settings.api_url.clone();
58 cx.spawn(async move |this, cx| {
59 credentials_provider
60 .delete_credentials(&api_url, cx)
61 .await
62 .log_err();
63 this.update(cx, |this, cx| {
64 this.api_key = None;
65 this.api_key_from_env = false;
66 cx.notify();
67 })
68 })
69 }
70
71 fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
72 let credentials_provider = <dyn CredentialsProvider>::global(cx);
73 let api_url = self.settings.api_url.clone();
74 cx.spawn(async move |this, cx| {
75 credentials_provider
76 .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
77 .await
78 .log_err();
79 this.update(cx, |this, cx| {
80 this.api_key = Some(api_key);
81 cx.notify();
82 })
83 })
84 }
85
86 fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
87 let credentials_provider = <dyn CredentialsProvider>::global(cx);
88 let env_var_name = self.env_var_name.clone();
89 let api_url = self.settings.api_url.clone();
90 cx.spawn(async move |this, cx| {
91 let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
92 (api_key, true)
93 } else {
94 let (_, api_key) = credentials_provider
95 .read_credentials(&api_url, cx)
96 .await?
97 .ok_or(AuthenticateError::CredentialsNotFound)?;
98 (
99 String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
100 false,
101 )
102 };
103 this.update(cx, |this, cx| {
104 this.api_key = Some(api_key);
105 this.api_key_from_env = from_env;
106 cx.notify();
107 })?;
108
109 Ok(())
110 })
111 }
112
113 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
114 if self.is_authenticated() {
115 return Task::ready(Ok(()));
116 }
117
118 self.get_api_key(cx)
119 }
120}
121
122impl OpenAiCompatibleLanguageModelProvider {
123 pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
124 fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
125 AllLanguageModelSettings::get_global(cx)
126 .openai_compatible
127 .get(id)
128 }
129
130 let state = cx.new(|cx| State {
131 id: id.clone(),
132 env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
133 settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
134 api_key: None,
135 api_key_from_env: false,
136 _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
137 let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
138 return;
139 };
140 if &this.settings != &settings {
141 if settings.api_url != this.settings.api_url && !this.api_key_from_env {
142 let spawn_task = cx.spawn(async move |handle, cx| {
143 if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
144 if let Err(_) = task.await {
145 handle
146 .update(cx, |this, _| {
147 this.api_key = None;
148 this.api_key_from_env = false;
149 })
150 .ok();
151 }
152 }
153 });
154 spawn_task.detach();
155 }
156
157 this.settings = settings;
158 cx.notify();
159 }
160 }),
161 });
162
163 Self {
164 id: id.clone().into(),
165 name: id.into(),
166 http_client,
167 state,
168 }
169 }
170
171 fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
172 Arc::new(OpenAiCompatibleLanguageModel {
173 id: LanguageModelId::from(model.name.clone()),
174 provider_id: self.id.clone(),
175 provider_name: self.name.clone(),
176 model,
177 state: self.state.clone(),
178 http_client: self.http_client.clone(),
179 request_limiter: RateLimiter::new(4),
180 })
181 }
182}
183
184impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
185 type ObservableEntity = State;
186
187 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
188 Some(self.state.clone())
189 }
190}
191
192impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
193 fn id(&self) -> LanguageModelProviderId {
194 self.id.clone()
195 }
196
197 fn name(&self) -> LanguageModelProviderName {
198 self.name.clone()
199 }
200
201 fn icon(&self) -> IconName {
202 IconName::AiOpenAiCompat
203 }
204
205 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
206 self.state
207 .read(cx)
208 .settings
209 .available_models
210 .first()
211 .map(|model| self.create_language_model(model.clone()))
212 }
213
214 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
215 None
216 }
217
218 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
219 self.state
220 .read(cx)
221 .settings
222 .available_models
223 .iter()
224 .map(|model| self.create_language_model(model.clone()))
225 .collect()
226 }
227
228 fn is_authenticated(&self, cx: &App) -> bool {
229 self.state.read(cx).is_authenticated()
230 }
231
232 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
233 self.state.update(cx, |state, cx| state.authenticate(cx))
234 }
235
236 fn configuration_view(
237 &self,
238 _target_agent: language_model::ConfigurationViewTargetAgent,
239 window: &mut Window,
240 cx: &mut App,
241 ) -> AnyView {
242 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
243 .into()
244 }
245
246 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247 self.state.update(cx, |state, cx| state.reset_api_key(cx))
248 }
249}
250
251pub struct OpenAiCompatibleLanguageModel {
252 id: LanguageModelId,
253 provider_id: LanguageModelProviderId,
254 provider_name: LanguageModelProviderName,
255 model: AvailableModel,
256 state: gpui::Entity<State>,
257 http_client: Arc<dyn HttpClient>,
258 request_limiter: RateLimiter,
259}
260
261impl OpenAiCompatibleLanguageModel {
262 fn stream_completion(
263 &self,
264 request: open_ai::Request,
265 cx: &AsyncApp,
266 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
267 {
268 let http_client = self.http_client.clone();
269 let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
270 (state.api_key.clone(), state.settings.api_url.clone())
271 }) else {
272 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
273 };
274
275 let provider = self.provider_name.clone();
276 let future = self.request_limiter.stream(async move {
277 let Some(api_key) = api_key else {
278 return Err(LanguageModelCompletionError::NoApiKey { provider });
279 };
280 let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
281 let response = request.await?;
282 Ok(response)
283 });
284
285 async move { Ok(future.await?.boxed()) }.boxed()
286 }
287}
288
289impl LanguageModel for OpenAiCompatibleLanguageModel {
290 fn id(&self) -> LanguageModelId {
291 self.id.clone()
292 }
293
294 fn name(&self) -> LanguageModelName {
295 LanguageModelName::from(
296 self.model
297 .display_name
298 .clone()
299 .unwrap_or_else(|| self.model.name.clone()),
300 )
301 }
302
303 fn provider_id(&self) -> LanguageModelProviderId {
304 self.provider_id.clone()
305 }
306
307 fn provider_name(&self) -> LanguageModelProviderName {
308 self.provider_name.clone()
309 }
310
311 fn supports_tools(&self) -> bool {
312 self.model.capabilities.tools
313 }
314
315 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
316 LanguageModelToolSchemaFormat::JsonSchemaSubset
317 }
318
319 fn supports_images(&self) -> bool {
320 self.model.capabilities.images
321 }
322
323 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
324 match choice {
325 LanguageModelToolChoice::Auto => self.model.capabilities.tools,
326 LanguageModelToolChoice::Any => self.model.capabilities.tools,
327 LanguageModelToolChoice::None => true,
328 }
329 }
330
331 fn telemetry_id(&self) -> String {
332 format!("openai/{}", self.model.name)
333 }
334
335 fn max_token_count(&self) -> u64 {
336 self.model.max_tokens
337 }
338
339 fn max_output_tokens(&self) -> Option<u64> {
340 self.model.max_output_tokens
341 }
342
343 fn count_tokens(
344 &self,
345 request: LanguageModelRequest,
346 cx: &App,
347 ) -> BoxFuture<'static, Result<u64>> {
348 let max_token_count = self.max_token_count();
349 cx.background_spawn(async move {
350 let messages = super::open_ai::collect_tiktoken_messages(request);
351 let model = if max_token_count >= 100_000 {
352 // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
353 "gpt-4o"
354 } else {
355 // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
356 // supported with this tiktoken method
357 "gpt-4"
358 };
359 tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
360 })
361 .boxed()
362 }
363
364 fn stream_completion(
365 &self,
366 request: LanguageModelRequest,
367 cx: &AsyncApp,
368 ) -> BoxFuture<
369 'static,
370 Result<
371 futures::stream::BoxStream<
372 'static,
373 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
374 >,
375 LanguageModelCompletionError,
376 >,
377 > {
378 let request = into_open_ai(
379 request,
380 &self.model.name,
381 self.model.capabilities.parallel_tool_calls,
382 self.model.capabilities.prompt_cache_key,
383 self.max_output_tokens(),
384 None,
385 );
386 let completions = self.stream_completion(request, cx);
387 async move {
388 let mapper = OpenAiEventMapper::new();
389 Ok(mapper.map_stream(completions.await?).boxed())
390 }
391 .boxed()
392 }
393}
394
395struct ConfigurationView {
396 api_key_editor: Entity<SingleLineInput>,
397 state: gpui::Entity<State>,
398 load_credentials_task: Option<Task<()>>,
399}
400
401impl ConfigurationView {
402 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
403 let api_key_editor = cx.new(|cx| {
404 SingleLineInput::new(
405 window,
406 cx,
407 "000000000000000000000000000000000000000000000000000",
408 )
409 });
410
411 cx.observe(&state, |_, _, cx| {
412 cx.notify();
413 })
414 .detach();
415
416 let load_credentials_task = Some(cx.spawn_in(window, {
417 let state = state.clone();
418 async move |this, cx| {
419 if let Some(task) = state
420 .update(cx, |state, cx| state.authenticate(cx))
421 .log_err()
422 {
423 // We don't log an error, because "not signed in" is also an error.
424 let _ = task.await;
425 }
426 this.update(cx, |this, cx| {
427 this.load_credentials_task = None;
428 cx.notify();
429 })
430 .log_err();
431 }
432 }));
433
434 Self {
435 api_key_editor,
436 state,
437 load_credentials_task,
438 }
439 }
440
441 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
442 let api_key = self
443 .api_key_editor
444 .read(cx)
445 .editor()
446 .read(cx)
447 .text(cx)
448 .trim()
449 .to_string();
450
451 // Don't proceed if no API key is provided and we're not authenticated
452 if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
453 return;
454 }
455
456 let state = self.state.clone();
457 cx.spawn_in(window, async move |_, cx| {
458 state
459 .update(cx, |state, cx| state.set_api_key(api_key, cx))?
460 .await
461 })
462 .detach_and_log_err(cx);
463
464 cx.notify();
465 }
466
467 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
468 self.api_key_editor.update(cx, |input, cx| {
469 input.editor.update(cx, |editor, cx| {
470 editor.set_text("", window, cx);
471 });
472 });
473
474 let state = self.state.clone();
475 cx.spawn_in(window, async move |_, cx| {
476 state.update(cx, |state, cx| state.reset_api_key(cx))?.await
477 })
478 .detach_and_log_err(cx);
479
480 cx.notify();
481 }
482
483 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
484 !self.state.read(cx).is_authenticated()
485 }
486}
487
488impl Render for ConfigurationView {
489 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
490 let env_var_set = self.state.read(cx).api_key_from_env;
491 let env_var_name = self.state.read(cx).env_var_name.clone();
492
493 let api_key_section = if self.should_render_editor(cx) {
494 v_flex()
495 .on_action(cx.listener(Self::save_api_key))
496 .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
497 .child(
498 div()
499 .pt(DynamicSpacing::Base04.rems(cx))
500 .child(self.api_key_editor.clone())
501 )
502 .child(
503 Label::new(
504 format!("You can also assign the {env_var_name} environment variable and restart Zed."),
505 )
506 .size(LabelSize::Small).color(Color::Muted),
507 )
508 .into_any()
509 } else {
510 h_flex()
511 .mt_1()
512 .p_1()
513 .justify_between()
514 .rounded_md()
515 .border_1()
516 .border_color(cx.theme().colors().border)
517 .bg(cx.theme().colors().background)
518 .child(
519 h_flex()
520 .gap_1()
521 .child(Icon::new(IconName::Check).color(Color::Success))
522 .child(Label::new(if env_var_set {
523 format!("API key set in {env_var_name} environment variable.")
524 } else {
525 "API key configured.".to_string()
526 })),
527 )
528 .child(
529 Button::new("reset-api-key", "Reset API Key")
530 .label_size(LabelSize::Small)
531 .icon(IconName::Undo)
532 .icon_size(IconSize::Small)
533 .icon_position(IconPosition::Start)
534 .layer(ElevationIndex::ModalSurface)
535 .when(env_var_set, |this| {
536 this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
537 })
538 .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
539 )
540 .into_any()
541 };
542
543 if self.load_credentials_task.is_some() {
544 div().child(Label::new("Loading credentials…")).into_any()
545 } else {
546 v_flex().size_full().child(api_key_section).into_any()
547 }
548 }
549}