1use anyhow::Result;
2use convert_case::{Case, Casing};
3use credentials_provider::CredentialsProvider;
4use futures::{FutureExt, StreamExt, future::BoxFuture};
5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
6use http_client::HttpClient;
7use language_model::{
8 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
9 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
10 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
11 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
12};
13use menu;
14use open_ai::{
15 ResponseStreamEvent,
16 responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response},
17 stream_completion,
18};
19use settings::{Settings, SettingsStore};
20use std::sync::Arc;
21use ui::{ElevationIndex, Tooltip, prelude::*};
22use ui_input::InputField;
23use util::ResultExt;
24
25use crate::provider::open_ai::{
26 OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
27};
28pub use settings::OpenAiCompatibleAvailableModel as AvailableModel;
29pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
30
/// Per-endpooint settings resolved from the `openai_compatible` map in
/// `AllLanguageModelSettings` (see `resolve_settings` below).
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
    /// Base URL of the OpenAI-compatible API this provider talks to.
    pub api_url: String,
    /// Models the user has declared for this endpoint; the first one acts as
    /// the default model.
    pub available_models: Vec<AvailableModel>,
}
36
/// A language-model provider backed by an arbitrary OpenAI-compatible HTTP
/// endpoint. One instance exists per `openai_compatible` settings entry; the
/// entry's key serves as both the provider id and display name.
pub struct OpenAiCompatibleLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    http_client: Arc<dyn HttpClient>,
    // Shared, observable state (settings snapshot + API-key bookkeeping).
    state: Entity<State>,
}
43
/// Observable state for one OpenAI-compatible provider.
pub struct State {
    // Provider id; doubles as the key into the `openai_compatible` settings map.
    id: Arc<str>,
    // Tracks the API key for the current `api_url` (credentials store or env var).
    api_key_state: ApiKeyState,
    // Latest settings snapshot, kept in sync by the `SettingsStore` observer
    // installed in `OpenAiCompatibleLanguageModelProvider::new`.
    settings: OpenAiCompatibleSettings,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
50
51impl State {
52 fn is_authenticated(&self) -> bool {
53 self.api_key_state.has_key()
54 }
55
56 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
57 let credentials_provider = self.credentials_provider.clone();
58 let api_url = SharedString::new(self.settings.api_url.as_str());
59 self.api_key_state.store(
60 api_url,
61 api_key,
62 |this| &mut this.api_key_state,
63 credentials_provider,
64 cx,
65 )
66 }
67
68 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
69 let credentials_provider = self.credentials_provider.clone();
70 let api_url = SharedString::new(self.settings.api_url.clone());
71 self.api_key_state.load_if_needed(
72 api_url,
73 |this| &mut this.api_key_state,
74 credentials_provider,
75 cx,
76 )
77 }
78}
79
80impl OpenAiCompatibleLanguageModelProvider {
81 pub fn new(
82 id: Arc<str>,
83 http_client: Arc<dyn HttpClient>,
84 credentials_provider: Arc<dyn CredentialsProvider>,
85 cx: &mut App,
86 ) -> Self {
87 fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
88 crate::AllLanguageModelSettings::get_global(cx)
89 .openai_compatible
90 .get(id)
91 }
92
93 let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into();
94 let state = cx.new(|cx| {
95 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
96 let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
97 return;
98 };
99 if &this.settings != &settings {
100 let credentials_provider = this.credentials_provider.clone();
101 let api_url = SharedString::new(settings.api_url.as_str());
102 this.api_key_state.handle_url_change(
103 api_url,
104 |this| &mut this.api_key_state,
105 credentials_provider,
106 cx,
107 );
108 this.settings = settings;
109 cx.notify();
110 }
111 })
112 .detach();
113 let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
114 State {
115 id: id.clone(),
116 api_key_state: ApiKeyState::new(
117 SharedString::new(settings.api_url.as_str()),
118 EnvVar::new(api_key_env_var_name),
119 ),
120 settings,
121 credentials_provider,
122 }
123 });
124
125 Self {
126 id: id.clone().into(),
127 name: id.into(),
128 http_client,
129 state,
130 }
131 }
132
133 fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
134 Arc::new(OpenAiCompatibleLanguageModel {
135 id: LanguageModelId::from(model.name.clone()),
136 provider_id: self.id.clone(),
137 provider_name: self.name.clone(),
138 model,
139 state: self.state.clone(),
140 http_client: self.http_client.clone(),
141 request_limiter: RateLimiter::new(4),
142 })
143 }
144}
145
impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes `State` so observers can re-query when settings or the API key
    /// change.
    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
153
154impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
155 fn id(&self) -> LanguageModelProviderId {
156 self.id.clone()
157 }
158
159 fn name(&self) -> LanguageModelProviderName {
160 self.name.clone()
161 }
162
163 fn icon(&self) -> IconOrSvg {
164 IconOrSvg::Icon(IconName::AiOpenAiCompat)
165 }
166
167 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
168 self.state
169 .read(cx)
170 .settings
171 .available_models
172 .first()
173 .map(|model| self.create_language_model(model.clone()))
174 }
175
176 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
177 None
178 }
179
180 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
181 self.state
182 .read(cx)
183 .settings
184 .available_models
185 .iter()
186 .map(|model| self.create_language_model(model.clone()))
187 .collect()
188 }
189
190 fn is_authenticated(&self, cx: &App) -> bool {
191 self.state.read(cx).is_authenticated()
192 }
193
194 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
195 self.state.update(cx, |state, cx| state.authenticate(cx))
196 }
197
198 fn configuration_view(
199 &self,
200 _target_agent: language_model::ConfigurationViewTargetAgent,
201 window: &mut Window,
202 cx: &mut App,
203 ) -> AnyView {
204 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
205 .into()
206 }
207
208 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
209 self.state
210 .update(cx, |state, cx| state.set_api_key(None, cx))
211 }
212}
213
/// One concrete model offered by an OpenAI-compatible provider, wrapping the
/// settings-declared `AvailableModel` together with everything needed to make
/// requests against the endpoint.
pub struct OpenAiCompatibleLanguageModel {
    id: LanguageModelId,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    // Settings entry: name, display name, token limits, capabilities.
    model: AvailableModel,
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Bounds concurrent requests for this model (set to 4 at construction).
    request_limiter: RateLimiter,
}
223
impl OpenAiCompatibleLanguageModel {
    /// Issues a streaming Chat Completions request against the configured
    /// endpoint, gated through `request_limiter`.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();

        // Snapshot the key and URL up front so the returned future is
        // 'static and doesn't keep borrowing app state.
        let (api_key, api_url) = self.state.read_with(cx, |state, _cx| {
            let api_url = &state.settings.api_url;
            (
                state.api_key_state.key(api_url),
                state.settings.api_url.clone(),
            )
        });

        let provider = self.provider_name.clone();
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                // Provider-specific "no API key" error so the UI can prompt.
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    /// Issues a streaming Responses-API request; used when the model's
    /// capabilities mark Chat Completions as unsupported (see
    /// `LanguageModel::stream_completion` below).
    fn stream_response(
        &self,
        request: ResponseRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponsesStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        // Same snapshot pattern as `stream_completion` above.
        let (api_key, api_url) = self.state.read_with(cx, |state, _cx| {
            let api_url = &state.settings.api_url;
            (
                state.api_key_state.key(api_url),
                state.settings.api_url.clone(),
            )
        });

        let provider = self.provider_name.clone();
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_response(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
                // NOTE(review): the meaning of this trailing `vec![]` isn't
                // visible here — confirm against
                // `open_ai::responses::stream_response`'s signature.
                vec![],
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
301
impl LanguageModel for OpenAiCompatibleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    /// Prefers the user-facing display name, falling back to the raw model
    /// name from settings.
    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(
            self.model
                .display_name
                .clone()
                .unwrap_or_else(|| self.model.name.clone()),
        )
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }

    fn supports_tools(&self) -> bool {
        self.model.capabilities.tools
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn supports_images(&self) -> bool {
        self.model.capabilities.images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            // Auto/Any only make sense when the model supports tools at all;
            // declining tools (`None`) is always possible.
            LanguageModelToolChoice::Auto => self.model.capabilities.tools,
            LanguageModelToolChoice::Any => self.model.capabilities.tools,
            LanguageModelToolChoice::None => true,
        }
    }

    fn supports_streaming_tools(&self) -> bool {
        true
    }

    fn supports_split_token_display(&self) -> bool {
        true
    }

    fn telemetry_id(&self) -> String {
        // NOTE(review): telemetry reuses the "openai/" namespace instead of a
        // provider-specific prefix — confirm this is intentional.
        format!("openai/{}", self.model.name)
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_tokens
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens
    }

    /// Estimates the request's token count with tiktoken. The real tokenizer
    /// for an arbitrary OpenAI-compatible model is unknown, so a tokenizer is
    /// picked heuristically from the context-window size.
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let max_token_count = self.max_token_count();
        cx.background_spawn(async move {
            let messages = super::open_ai::collect_tiktoken_messages(request);
            let model = if max_token_count >= 100_000 {
                // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
                "gpt-4o"
            } else {
                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
                // supported with this tiktoken method
                "gpt-4"
            };
            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    /// Streams a completion, routing to the Chat Completions API when the
    /// model's capabilities allow it, and to the Responses API otherwise.
    /// Each branch maps the raw SSE events into `LanguageModelCompletionEvent`s.
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        if self.model.capabilities.chat_completions {
            let request = into_open_ai(
                request,
                &self.model.name,
                self.model.capabilities.parallel_tool_calls,
                self.model.capabilities.prompt_cache_key,
                self.max_output_tokens(),
                self.model.reasoning_effort,
            );
            // Calls the inherent `stream_completion` helper, not this method.
            let completions = self.stream_completion(request, cx);
            async move {
                let mapper = OpenAiEventMapper::new();
                Ok(mapper.map_stream(completions.await?).boxed())
            }
            .boxed()
        } else {
            let request = into_open_ai_response(
                request,
                &self.model.name,
                self.model.capabilities.parallel_tool_calls,
                self.model.capabilities.prompt_cache_key,
                self.max_output_tokens(),
                self.model.reasoning_effort,
            );
            let completions = self.stream_response(request, cx);
            async move {
                let mapper = OpenAiResponseEventMapper::new();
                Ok(mapper.map_stream(completions.await?).boxed())
            }
            .boxed()
        }
    }
}
432
/// Settings-panel view for entering or resetting the provider's API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // `Some` while the initial credential load is in flight; `render` shows a
    // loading message until it clears.
    load_credentials_task: Option<Task<()>>,
}
438
439impl ConfigurationView {
440 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
441 let api_key_editor = cx.new(|cx| {
442 InputField::new(
443 window,
444 cx,
445 "000000000000000000000000000000000000000000000000000",
446 )
447 });
448
449 cx.observe(&state, |_, _, cx| {
450 cx.notify();
451 })
452 .detach();
453
454 let load_credentials_task = Some(cx.spawn_in(window, {
455 let state = state.clone();
456 async move |this, cx| {
457 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
458 // We don't log an error, because "not signed in" is also an error.
459 let _ = task.await;
460 }
461 this.update(cx, |this, cx| {
462 this.load_credentials_task = None;
463 cx.notify();
464 })
465 .log_err();
466 }
467 }));
468
469 Self {
470 api_key_editor,
471 state,
472 load_credentials_task,
473 }
474 }
475
476 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
477 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
478 if api_key.is_empty() {
479 return;
480 }
481
482 // url changes can cause the editor to be displayed again
483 self.api_key_editor
484 .update(cx, |input, cx| input.set_text("", window, cx));
485
486 let state = self.state.clone();
487 cx.spawn_in(window, async move |_, cx| {
488 state
489 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
490 .await
491 })
492 .detach_and_log_err(cx);
493 }
494
495 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
496 self.api_key_editor
497 .update(cx, |input, cx| input.set_text("", window, cx));
498
499 let state = self.state.clone();
500 cx.spawn_in(window, async move |_, cx| {
501 state
502 .update(cx, |state, cx| state.set_api_key(None, cx))
503 .await
504 })
505 .detach_and_log_err(cx);
506 }
507
508 fn should_render_editor(&self, cx: &Context<Self>) -> bool {
509 !self.state.read(cx).is_authenticated()
510 }
511}
512
impl Render for ConfigurationView {
    /// Renders either a key-entry form (no key configured), a "key configured"
    /// banner with a reset button, or a loading message while credentials are
    /// still being fetched.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let env_var_set = state.api_key_state.is_from_env_var();
        let env_var_name = state.api_key_state.env_var_name();

        let api_key_section = if self.should_render_editor(cx) {
            // No key yet: show the input plus a hint about the env var.
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
                .child(
                    div()
                        .pt(DynamicSpacing::Base04.rems(cx))
                        .child(self.api_key_editor.clone())
                )
                .child(
                    Label::new(
                        format!("You can also set the {env_var_name} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Key present: show where it came from and how to reset it.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .flex_1()
                        .min_w_0()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(
                            div()
                                .w_full()
                                .overflow_x_hidden()
                                .text_ellipsis()
                                .child(Label::new(
                                    if env_var_set {
                                        format!("API key set in {env_var_name} environment variable")
                                    } else {
                                        format!("API key configured for {}", &state.settings.api_url)
                                    }
                                ))
                        ),
                )
                .child(
                    h_flex()
                        .flex_shrink_0()
                        .child(
                            Button::new("reset-api-key", "Reset API Key")
                                .label_size(LabelSize::Small)
                                .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
                                .layer(ElevationIndex::ModalSurface)
                                // Env-var keys can't be reset from the UI, so
                                // explain how to unset them instead.
                                .when(env_var_set, |this| {
                                    this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
                                })
                                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                        ),
                )
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex().size_full().child(api_key_section).into_any()
        }
    }
}