1use anyhow::Result;
2use convert_case::{Case, Casing};
3use credentials_provider::CredentialsProvider;
4use futures::{FutureExt, StreamExt, future::BoxFuture};
5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
6use http_client::HttpClient;
7use language_model::{
8 ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
9 LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
10 LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
11 LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
12};
13use menu;
14use open_ai::{
15 ResponseStreamEvent,
16 responses::{Request as ResponseRequest, StreamEvent as ResponsesStreamEvent, stream_response},
17 stream_completion,
18};
19use settings::{Settings, SettingsStore};
20use std::sync::Arc;
21use ui::{ElevationIndex, Tooltip, prelude::*};
22use ui_input::InputField;
23use util::ResultExt;
24
25use crate::provider::open_ai::{
26 OpenAiEventMapper, OpenAiResponseEventMapper, into_open_ai, into_open_ai_response,
27};
28pub use settings::OpenAiCompatibleAvailableModel as AvailableModel;
29pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
30
/// Per-provider settings resolved from the user's settings file for one
/// OpenAI-compatible endpoint.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
    /// Base URL of the OpenAI-compatible API endpoint.
    pub api_url: String,
    /// Models the user has declared for this provider.
    pub available_models: Vec<AvailableModel>,
}
36
/// Language-model provider backed by an arbitrary OpenAI-compatible endpoint,
/// configured entirely through user settings.
pub struct OpenAiCompatibleLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    http_client: Arc<dyn HttpClient>,
    // Shared mutable state (settings + API key), observed by the UI.
    state: Entity<State>,
}
43
/// Observable state for one provider instance: its resolved settings and the
/// API key associated with the configured URL.
pub struct State {
    /// Settings key identifying this provider (also used to derive the
    /// `<ID>_API_KEY` environment variable name).
    id: Arc<str>,
    /// API key scoped to the current `settings.api_url`.
    api_key_state: ApiKeyState,
    settings: OpenAiCompatibleSettings,
    credentials_provider: Arc<dyn CredentialsProvider>,
}
50
51impl State {
52 fn is_authenticated(&self) -> bool {
53 self.api_key_state.has_key()
54 }
55
56 fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
57 let credentials_provider = self.credentials_provider.clone();
58 let api_url = SharedString::new(self.settings.api_url.as_str());
59 self.api_key_state.store(
60 api_url,
61 api_key,
62 |this| &mut this.api_key_state,
63 credentials_provider,
64 cx,
65 )
66 }
67
68 fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
69 let credentials_provider = self.credentials_provider.clone();
70 let api_url = SharedString::new(self.settings.api_url.clone());
71 self.api_key_state.load_if_needed(
72 api_url,
73 |this| &mut this.api_key_state,
74 credentials_provider,
75 cx,
76 )
77 }
78}
79
80impl OpenAiCompatibleLanguageModelProvider {
81 pub fn new(
82 id: Arc<str>,
83 http_client: Arc<dyn HttpClient>,
84 credentials_provider: Arc<dyn CredentialsProvider>,
85 cx: &mut App,
86 ) -> Self {
87 fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
88 crate::AllLanguageModelSettings::get_global(cx)
89 .openai_compatible
90 .get(id)
91 }
92
93 let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into();
94 let state = cx.new(|cx| {
95 cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
96 let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
97 return;
98 };
99 if &this.settings != &settings {
100 let credentials_provider = this.credentials_provider.clone();
101 let api_url = SharedString::new(settings.api_url.as_str());
102 this.api_key_state.handle_url_change(
103 api_url,
104 |this| &mut this.api_key_state,
105 credentials_provider,
106 cx,
107 );
108 this.settings = settings;
109 cx.notify();
110 }
111 })
112 .detach();
113 let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
114 State {
115 id: id.clone(),
116 api_key_state: ApiKeyState::new(
117 SharedString::new(settings.api_url.as_str()),
118 EnvVar::new(api_key_env_var_name),
119 ),
120 settings,
121 credentials_provider,
122 }
123 });
124
125 Self {
126 id: id.clone().into(),
127 name: id.into(),
128 http_client,
129 state,
130 }
131 }
132
133 fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
134 Arc::new(OpenAiCompatibleLanguageModel {
135 id: LanguageModelId::from(model.name.clone()),
136 provider_id: self.id.clone(),
137 provider_name: self.name.clone(),
138 model,
139 state: self.state.clone(),
140 http_client: self.http_client.clone(),
141 request_limiter: RateLimiter::new(4),
142 })
143 }
144}
145
146impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
147 type ObservableEntity = State;
148
149 fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
150 Some(self.state.clone())
151 }
152}
153
154impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
155 fn id(&self) -> LanguageModelProviderId {
156 self.id.clone()
157 }
158
159 fn name(&self) -> LanguageModelProviderName {
160 self.name.clone()
161 }
162
163 fn icon(&self) -> IconOrSvg {
164 IconOrSvg::Icon(IconName::AiOpenAiCompat)
165 }
166
167 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
168 self.state
169 .read(cx)
170 .settings
171 .available_models
172 .first()
173 .map(|model| self.create_language_model(model.clone()))
174 }
175
176 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
177 None
178 }
179
180 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
181 self.state
182 .read(cx)
183 .settings
184 .available_models
185 .iter()
186 .map(|model| self.create_language_model(model.clone()))
187 .collect()
188 }
189
190 fn is_authenticated(&self, cx: &App) -> bool {
191 self.state.read(cx).is_authenticated()
192 }
193
194 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
195 self.state.update(cx, |state, cx| state.authenticate(cx))
196 }
197
198 fn configuration_view(
199 &self,
200 _target_agent: language_model::ConfigurationViewTargetAgent,
201 window: &mut Window,
202 cx: &mut App,
203 ) -> AnyView {
204 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
205 .into()
206 }
207
208 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
209 self.state
210 .update(cx, |state, cx| state.set_api_key(None, cx))
211 }
212}
213
/// A single model served by an OpenAI-compatible endpoint.
pub struct OpenAiCompatibleLanguageModel {
    id: LanguageModelId,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    /// Model entry from settings (name, capabilities, token limits, …).
    model: AvailableModel,
    /// Shared provider state; supplies the API key and URL per request.
    state: Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Caps the number of concurrent requests to this model.
    request_limiter: RateLimiter,
}
223
impl OpenAiCompatibleLanguageModel {
    /// Issues a streaming Chat Completions request.
    ///
    /// Reads the API key and URL from provider state, then sends the request
    /// through the rate limiter. Resolves to `NoApiKey` when no key is
    /// configured for this provider.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
            LanguageModelCompletionError,
        >,
    > {
        let http_client = self.http_client.clone();

        // Snapshot key + URL now so the returned future is 'static.
        let (api_key, api_url) = self.state.read_with(cx, |state, _cx| {
            let api_url = &state.settings.api_url;
            (
                state.api_key_state.key(api_url),
                state.settings.api_url.clone(),
            )
        });

        let provider = self.provider_name.clone();
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    /// Issues a streaming Responses-API request.
    ///
    /// Mirrors `stream_completion` but targets the Responses endpoint, used
    /// when the model does not support chat completions.
    fn stream_response(
        &self,
        request: ResponseRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponsesStreamEvent>>>>
    {
        let http_client = self.http_client.clone();

        // Snapshot key + URL now so the returned future is 'static.
        let (api_key, api_url) = self.state.read_with(cx, |state, _cx| {
            let api_url = &state.settings.api_url;
            (
                state.api_key_state.key(api_url),
                state.settings.api_url.clone(),
            )
        });

        let provider = self.provider_name.clone();
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_response(
                http_client.as_ref(),
                provider.0.as_str(),
                &api_url,
                &api_key,
                request,
            );
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
300
301impl LanguageModel for OpenAiCompatibleLanguageModel {
302 fn id(&self) -> LanguageModelId {
303 self.id.clone()
304 }
305
306 fn name(&self) -> LanguageModelName {
307 LanguageModelName::from(
308 self.model
309 .display_name
310 .clone()
311 .unwrap_or_else(|| self.model.name.clone()),
312 )
313 }
314
315 fn provider_id(&self) -> LanguageModelProviderId {
316 self.provider_id.clone()
317 }
318
319 fn provider_name(&self) -> LanguageModelProviderName {
320 self.provider_name.clone()
321 }
322
323 fn supports_tools(&self) -> bool {
324 self.model.capabilities.tools
325 }
326
327 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
328 LanguageModelToolSchemaFormat::JsonSchemaSubset
329 }
330
331 fn supports_images(&self) -> bool {
332 self.model.capabilities.images
333 }
334
335 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
336 match choice {
337 LanguageModelToolChoice::Auto => self.model.capabilities.tools,
338 LanguageModelToolChoice::Any => self.model.capabilities.tools,
339 LanguageModelToolChoice::None => true,
340 }
341 }
342
343 fn supports_streaming_tools(&self) -> bool {
344 true
345 }
346
347 fn supports_split_token_display(&self) -> bool {
348 true
349 }
350
351 fn telemetry_id(&self) -> String {
352 format!("openai/{}", self.model.name)
353 }
354
355 fn max_token_count(&self) -> u64 {
356 self.model.max_tokens
357 }
358
359 fn max_output_tokens(&self) -> Option<u64> {
360 self.model.max_output_tokens
361 }
362
363 fn count_tokens(
364 &self,
365 request: LanguageModelRequest,
366 cx: &App,
367 ) -> BoxFuture<'static, Result<u64>> {
368 let max_token_count = self.max_token_count();
369 cx.background_spawn(async move {
370 let messages = super::open_ai::collect_tiktoken_messages(request);
371 let model = if max_token_count >= 100_000 {
372 // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
373 "gpt-4o"
374 } else {
375 // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
376 // supported with this tiktoken method
377 "gpt-4"
378 };
379 tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
380 })
381 .boxed()
382 }
383
384 fn stream_completion(
385 &self,
386 request: LanguageModelRequest,
387 cx: &AsyncApp,
388 ) -> BoxFuture<
389 'static,
390 Result<
391 futures::stream::BoxStream<
392 'static,
393 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
394 >,
395 LanguageModelCompletionError,
396 >,
397 > {
398 if self.model.capabilities.chat_completions {
399 let request = into_open_ai(
400 request,
401 &self.model.name,
402 self.model.capabilities.parallel_tool_calls,
403 self.model.capabilities.prompt_cache_key,
404 self.max_output_tokens(),
405 None,
406 );
407 let completions = self.stream_completion(request, cx);
408 async move {
409 let mapper = OpenAiEventMapper::new();
410 Ok(mapper.map_stream(completions.await?).boxed())
411 }
412 .boxed()
413 } else {
414 let request = into_open_ai_response(
415 request,
416 &self.model.name,
417 self.model.capabilities.parallel_tool_calls,
418 self.model.capabilities.prompt_cache_key,
419 self.max_output_tokens(),
420 None,
421 );
422 let completions = self.stream_response(request, cx);
423 async move {
424 let mapper = OpenAiResponseEventMapper::new();
425 Ok(mapper.map_stream(completions.await?).boxed())
426 }
427 .boxed()
428 }
429 }
430}
431
/// Settings-panel view for entering/resetting this provider's API key.
struct ConfigurationView {
    api_key_editor: Entity<InputField>,
    state: Entity<State>,
    // Some(...) while credentials are being loaded; None once finished.
    load_credentials_task: Option<Task<()>>,
}
437
438impl ConfigurationView {
439 fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
440 let api_key_editor = cx.new(|cx| {
441 InputField::new(
442 window,
443 cx,
444 "000000000000000000000000000000000000000000000000000",
445 )
446 });
447
448 cx.observe(&state, |_, _, cx| {
449 cx.notify();
450 })
451 .detach();
452
453 let load_credentials_task = Some(cx.spawn_in(window, {
454 let state = state.clone();
455 async move |this, cx| {
456 if let Some(task) = Some(state.update(cx, |state, cx| state.authenticate(cx))) {
457 // We don't log an error, because "not signed in" is also an error.
458 let _ = task.await;
459 }
460 this.update(cx, |this, cx| {
461 this.load_credentials_task = None;
462 cx.notify();
463 })
464 .log_err();
465 }
466 }));
467
468 Self {
469 api_key_editor,
470 state,
471 load_credentials_task,
472 }
473 }
474
475 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
476 let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
477 if api_key.is_empty() {
478 return;
479 }
480
481 // url changes can cause the editor to be displayed again
482 self.api_key_editor
483 .update(cx, |input, cx| input.set_text("", window, cx));
484
485 let state = self.state.clone();
486 cx.spawn_in(window, async move |_, cx| {
487 state
488 .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
489 .await
490 })
491 .detach_and_log_err(cx);
492 }
493
494 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
495 self.api_key_editor
496 .update(cx, |input, cx| input.set_text("", window, cx));
497
498 let state = self.state.clone();
499 cx.spawn_in(window, async move |_, cx| {
500 state
501 .update(cx, |state, cx| state.set_api_key(None, cx))
502 .await
503 })
504 .detach_and_log_err(cx);
505 }
506
507 fn should_render_editor(&self, cx: &Context<Self>) -> bool {
508 !self.state.read(cx).is_authenticated()
509 }
510}
511
impl Render for ConfigurationView {
    // Renders one of three states: a loading notice, the key-entry form, or
    // a "key configured" banner with a reset button.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let state = self.state.read(cx);
        let env_var_set = state.api_key_state.is_from_env_var();
        let env_var_name = state.api_key_state.env_var_name();

        let api_key_section = if self.should_render_editor(cx) {
            // Unauthenticated: show the editor plus the env-var hint.
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
                .child(
                    div()
                        .pt(DynamicSpacing::Base04.rems(cx))
                        .child(self.api_key_editor.clone())
                )
                .child(
                    Label::new(
                        format!("You can also set the {env_var_name} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Authenticated: show where the key came from and a reset control.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .flex_1()
                        .min_w_0()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(
                            div()
                                .w_full()
                                .overflow_x_hidden()
                                .text_ellipsis()
                                .child(Label::new(
                                    if env_var_set {
                                        format!("API key set in {env_var_name} environment variable")
                                    } else {
                                        format!("API key configured for {}", &state.settings.api_url)
                                    }
                                ))
                        ),
                )
                .child(
                    h_flex()
                        .flex_shrink_0()
                        .child(
                            Button::new("reset-api-key", "Reset API Key")
                                .label_size(LabelSize::Small)
                                .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
                                .layer(ElevationIndex::ModalSurface)
                                // Env-var keys can't be reset from the UI; explain why.
                                .when(env_var_set, |this| {
                                    this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
                                })
                                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                        ),
                )
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex().size_full().child(api_key_section).into_any()
        }
    }
}