use anyhow::{Context as _, Result, anyhow};
use credentials_provider::CredentialsProvider;

use convert_case::{Case, Casing};
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
};
use menu;
use open_ai::{ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::sync::Arc;

use ui::{ElevationIndex, Tooltip, prelude::*};
use ui_input::SingleLineInput;
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};

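/// Per-provider settings for an OpenAI-compatible endpoint: the base API URL
/// and the models it exposes.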
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

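/// A model entry declared in settings for an OpenAI-compatible provider.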
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    #[serde(default)]
    pub capabilities: ModelCapabilities,
}

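/// Capabilities a configured model is declared to support; these gate tool
/// calling, image input, parallel tool calls, and prompt caching.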
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ModelCapabilities {
    pub tools: bool,
    pub images: bool,
    pub parallel_tool_calls: bool,
    pub prompt_cache_key: bool,
}

impl Default for ModelCapabilities {
    fn default() -> Self {
        Self {
            tools: true,
            images: false,
            parallel_tool_calls: false,
            prompt_cache_key: false,
        }
    }
}

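/// A language model provider backed by any endpoint that speaks the OpenAI
/// chat completions protocol.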
pub struct OpenAiCompatibleLanguageModelProvider {
    id: LanguageModelProviderId,
    name: LanguageModelProviderName,
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

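/// Shared provider state: the resolved settings plus the API key and whether
/// it came from an environment variable.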
pub struct State {
    id: Arc<str>,
    env_var_name: Arc<str>,
    api_key: Option<String>,
    api_key_from_env: bool,
    settings: OpenAiCompatibleSettings,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        self.api_key.is_some()
    }

    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = self.settings.api_url.clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .delete_credentials(&api_url, cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = None;
                this.api_key_from_env = false;
                cx.notify();
            })
        })
    }

    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let api_url = self.settings.api_url.clone();
        cx.spawn(async move |this, cx| {
            credentials_provider
                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
                .await
                .log_err();
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                cx.notify();
            })
        })
    }

    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let credentials_provider = <dyn CredentialsProvider>::global(cx);
        let env_var_name = self.env_var_name.clone();
        let api_url = self.settings.api_url.clone();
        let id = self.id.clone();
        cx.spawn(async move |this, cx| {
            // Prefer the environment variable; otherwise fall back to the credentials store.
            let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
                (api_key, true)
            } else {
                let (_, api_key) = credentials_provider
                    .read_credentials(&api_url, cx)
                    .await?
                    .ok_or(AuthenticateError::CredentialsNotFound)?;
                (
                    String::from_utf8(api_key).context(format!("invalid {id} API key"))?,
                    false,
                )
            };
            this.update(cx, |this, cx| {
                this.api_key = Some(api_key);
                this.api_key_from_env = from_env;
                cx.notify();
            })?;

            Ok(())
        })
    }
}

impl OpenAiCompatibleLanguageModelProvider {
    pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
            AllLanguageModelSettings::get_global(cx)
                .openai_compatible
                .get(id)
        }

        let state = cx.new(|cx| State {
            id: id.clone(),
            env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
            settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
            api_key: None,
            api_key_from_env: false,
            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                let Some(settings) = resolve_settings(&this.id, cx) else {
                    return;
                };
                if &this.settings != settings {
                    this.settings = settings.clone();
                    cx.notify();
                }
            }),
        });

        Self {
            id: id.clone().into(),
            name: id.into(),
            http_client,
            state,
        }
    }

    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
        Arc::new(OpenAiCompatibleLanguageModel {
            id: LanguageModelId::from(model.name.clone()),
            provider_id: self.id.clone(),
            provider_name: self.name.clone(),
            model,
            state: self.state.clone(),
            http_client: self.http_client.clone(),
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelProviderName {
        self.name.clone()
    }

    fn icon(&self) -> IconName {
        IconName::AiOpenAiCompat
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .settings
            .available_models
            .first()
            .map(|model| self.create_language_model(model.clone()))
    }

    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
        None
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .settings
            .available_models
            .iter()
            .map(|model| self.create_language_model(model.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.reset_api_key(cx))
    }
}

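/// A single model served by an OpenAI-compatible endpoint, streaming
/// completions through the shared provider state.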
pub struct OpenAiCompatibleLanguageModel {
    id: LanguageModelId,
    provider_id: LanguageModelProviderId,
    provider_name: LanguageModelProviderName,
    model: AvailableModel,
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OpenAiCompatibleLanguageModel {
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
            (state.api_key.clone(), state.settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let provider = self.provider_name.clone();
        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey { provider });
            };
            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

impl LanguageModel for OpenAiCompatibleLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(
            self.model
                .display_name
                .clone()
                .unwrap_or_else(|| self.model.name.clone()),
        )
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        self.provider_id.clone()
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        self.provider_name.clone()
    }

    fn supports_tools(&self) -> bool {
        self.model.capabilities.tools
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        LanguageModelToolSchemaFormat::JsonSchemaSubset
    }

    fn supports_images(&self) -> bool {
        self.model.capabilities.images
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto | LanguageModelToolChoice::Any => {
                self.model.capabilities.tools
            }
            LanguageModelToolChoice::None => true,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("openai/{}", self.model.name)
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_tokens
    }

    fn max_output_tokens(&self) -> Option<u64> {
        self.model.max_output_tokens
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        let max_token_count = self.max_token_count();
        cx.background_spawn(async move {
            let messages = super::open_ai::collect_tiktoken_messages(request);
            let model = if max_token_count >= 100_000 {
                // A context window of 100k or more likely means the o200k_base
                // tokenizer used by gpt-4o.
                "gpt-4o"
            } else {
                // Otherwise fall back to gpt-4, since only cl100k_base and
                // o200k_base are supported by this tiktoken method.
                "gpt-4"
            };
            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
        })
        .boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            futures::stream::BoxStream<
                'static,
                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
            >,
            LanguageModelCompletionError,
        >,
    > {
        let request = into_open_ai(
            request,
            &self.model.name,
            self.model.capabilities.parallel_tool_calls,
            self.model.capabilities.prompt_cache_key,
            self.max_output_tokens(),
            None,
        );
        let completions = self.stream_completion(request, cx);
        async move {
            let mapper = OpenAiEventMapper::new();
            Ok(mapper.map_stream(completions.await?).boxed())
        }
        .boxed()
    }
}

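/// Configuration UI for entering, saving, and resetting the provider's API key.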
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    load_credentials_task: Option<Task<()>>,
}

impl ConfigurationView {
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "000000000000000000000000000000000000000000000000000",
            )
        });

        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log this error, since "not signed in" is also
                    // reported as an error here.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated.
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;
        let env_var_name = self.state.read(cx).env_var_name.clone();

        let api_key_section = if self.should_render_editor(cx) {
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
                .child(
                    div()
                        .pt(DynamicSpacing::Base04.rems(cx))
                        .child(self.api_key_editor.clone())
                )
                .child(
                    Label::new(
                        format!("You can also assign the {env_var_name} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {env_var_name} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex().size_full().child(api_key_section).into_any()
        }
    }
}