1use anyhow::{Context as _, Result, anyhow};
2use credentials_provider::CredentialsProvider;
3
4use convert_case::{Case, Casing};
5use futures::{FutureExt, StreamExt, future::BoxFuture};
6use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
7use http_client::HttpClient;
8use language_model::{
9 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
10 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
11 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
12 LanguageModelToolChoice, RateLimiter,
13};
14use menu;
15use open_ai::{ResponseStreamEvent, stream_completion};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::{Settings, SettingsStore};
19use std::sync::Arc;
20
21use ui::{ElevationIndex, Tooltip, prelude::*};
22use ui_input::SingleLineInput;
23use util::ResultExt;
24
25use crate::AllLanguageModelSettings;
26use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
27
/// Settings for a single OpenAI-compatible provider entry.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiCompatibleSettings {
    /// API URL passed to `stream_completion`; also used as the key under which the API key is
    /// stored in, read from, and deleted from the credentials provider.
    pub api_url: String,
    /// Models offered by this provider. The first entry is returned as the default model.
    pub available_models: Vec<AvailableModel>,
}
33
/// A user-configured model served by an OpenAI-compatible endpoint.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model name; used as the language model id, sent in requests, and used in the telemetry id.
    pub name: String,
    /// Optional human-readable name; when absent, `name` is shown instead.
    pub display_name: Option<String>,
    /// Value reported as the model's maximum token count.
    pub max_tokens: u64,
    /// Optional output-token limit; reported by the model and passed to request construction.
    pub max_output_tokens: Option<u64>,
    // NOTE(review): not read anywhere in the visible part of this file — confirm where (or
    // whether) this value is consumed.
    pub max_completion_tokens: Option<u64>,
    /// Feature flags for the model; falls back to `ModelCapabilities::default()` when omitted.
    #[serde(default)]
    pub capabilities: ModelCapabilities,
}
44
/// Feature flags describing what a configured model supports.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct ModelCapabilities {
    /// Whether the model supports tools; also gates the `Auto` and `Any` tool choices.
    pub tools: bool,
    /// Whether the model reports image support.
    pub images: bool,
    /// Forwarded to `into_open_ai` when building a request.
    pub parallel_tool_calls: bool,
    /// Forwarded to `into_open_ai` when building a request.
    pub prompt_cache_key: bool,
}
52
53impl Default for ModelCapabilities {
54 fn default() -> Self {
55 Self {
56 tools: true,
57 images: false,
58 parallel_tool_calls: false,
59 prompt_cache_key: false,
60 }
61 }
62}
63
/// Language model provider backed by a user-configured OpenAI-compatible endpoint.
pub struct OpenAiCompatibleLanguageModelProvider {
    /// Provider id; built from the id string passed to `new`.
    id: LanguageModelProviderId,
    /// Provider name; built from the same id string as `id`.
    name: LanguageModelProviderName,
    /// HTTP client shared with every model created by this provider.
    http_client: Arc<dyn HttpClient>,
    /// Shared entity holding the API key and the current settings.
    state: gpui::Entity<State>,
}
70
/// Authentication and settings state shared by the provider, its models, and its
/// configuration view.
pub struct State {
    /// Key used to look up this provider's entry in the `openai_compatible` settings map.
    id: Arc<str>,
    /// Name of the environment variable checked for an API key during authentication.
    env_var_name: Arc<str>,
    /// API key currently held in memory; `None` means not authenticated.
    api_key: Option<String>,
    /// Whether `api_key` was read from the environment variable rather than stored credentials.
    api_key_from_env: bool,
    /// Last observed settings for this provider.
    settings: OpenAiCompatibleSettings,
    // Handle for the `SettingsStore` observer registered in `new`; presumably held so the
    // observer stays registered for the lifetime of this state — confirm against gpui.
    _subscription: Subscription,
}
79
80impl State {
81 fn is_authenticated(&self) -> bool {
82 self.api_key.is_some()
83 }
84
85 fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
86 let credentials_provider = <dyn CredentialsProvider>::global(cx);
87 let api_url = self.settings.api_url.clone();
88 cx.spawn(async move |this, cx| {
89 credentials_provider
90 .delete_credentials(&api_url, cx)
91 .await
92 .log_err();
93 this.update(cx, |this, cx| {
94 this.api_key = None;
95 this.api_key_from_env = false;
96 cx.notify();
97 })
98 })
99 }
100
101 fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
102 let credentials_provider = <dyn CredentialsProvider>::global(cx);
103 let api_url = self.settings.api_url.clone();
104 cx.spawn(async move |this, cx| {
105 credentials_provider
106 .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
107 .await
108 .log_err();
109 this.update(cx, |this, cx| {
110 this.api_key = Some(api_key);
111 cx.notify();
112 })
113 })
114 }
115
116 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
117 if self.is_authenticated() {
118 return Task::ready(Ok(()));
119 }
120
121 let credentials_provider = <dyn CredentialsProvider>::global(cx);
122 let env_var_name = self.env_var_name.clone();
123 let api_url = self.settings.api_url.clone();
124 cx.spawn(async move |this, cx| {
125 let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
126 (api_key, true)
127 } else {
128 let (_, api_key) = credentials_provider
129 .read_credentials(&api_url, cx)
130 .await?
131 .ok_or(AuthenticateError::CredentialsNotFound)?;
132 (
133 String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
134 false,
135 )
136 };
137 this.update(cx, |this, cx| {
138 this.api_key = Some(api_key);
139 this.api_key_from_env = from_env;
140 cx.notify();
141 })?;
142
143 Ok(())
144 })
145 }
146}
147
148impl OpenAiCompatibleLanguageModelProvider {
149 pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
150 fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
151 AllLanguageModelSettings::get_global(cx)
152 .openai_compatible
153 .get(id)
154 }
155
156 let state = cx.new(|cx| State {
157 id: id.clone(),
158 env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
159 settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
160 api_key: None,
161 api_key_from_env: false,
162 _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
163 let Some(settings) = resolve_settings(&this.id, cx) else {
164 return;
165 };
166 if &this.settings != settings {
167 this.settings = settings.clone();
168 cx.notify();
169 }
170 }),
171 });
172
173 Self {
174 id: id.clone().into(),
175 name: id.into(),
176 http_client,
177 state,
178 }
179 }
180
181 fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
182 Arc::new(OpenAiCompatibleLanguageModel {
183 id: LanguageModelId::from(model.name.clone()),
184 provider_id: self.id.clone(),
185 provider_name: self.name.clone(),
186 model,
187 state: self.state.clone(),
188 http_client: self.http_client.clone(),
189 request_limiter: RateLimiter::new(4),
190 })
191 }
192}
193
194impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
195 type ObservableEntity = State;
196
197 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
198 Some(self.state.clone())
199 }
200}
201
202impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
203 fn id(&self) -> LanguageModelProviderId {
204 self.id.clone()
205 }
206
207 fn name(&self) -> LanguageModelProviderName {
208 self.name.clone()
209 }
210
211 fn icon(&self) -> IconName {
212 IconName::AiOpenAiCompat
213 }
214
215 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
216 self.state
217 .read(cx)
218 .settings
219 .available_models
220 .first()
221 .map(|model| self.create_language_model(model.clone()))
222 }
223
224 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
225 None
226 }
227
228 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
229 self.state
230 .read(cx)
231 .settings
232 .available_models
233 .iter()
234 .map(|model| self.create_language_model(model.clone()))
235 .collect()
236 }
237
238 fn is_authenticated(&self, cx: &App) -> bool {
239 self.state.read(cx).is_authenticated()
240 }
241
242 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
243 self.state.update(cx, |state, cx| state.authenticate(cx))
244 }
245
246 fn configuration_view(
247 &self,
248 _target_agent: language_model::ConfigurationViewTargetAgent,
249 window: &mut Window,
250 cx: &mut App,
251 ) -> AnyView {
252 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
253 .into()
254 }
255
256 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
257 self.state.update(cx, |state, cx| state.reset_api_key(cx))
258 }
259}
260
/// A single model served by an OpenAI-compatible provider.
pub struct OpenAiCompatibleLanguageModel {
    /// Model id; built from the configured model name.
    id: LanguageModelId,
    /// Id of the provider that created this model.
    provider_id: LanguageModelProviderId,
    /// Name of the provider that created this model; reported in the "no API key" error.
    provider_name: LanguageModelProviderName,
    /// The model's configuration (name, limits, capabilities).
    model: AvailableModel,
    /// Shared provider state; the API key and API URL are read from it for each request.
    state: gpui::Entity<State>,
    /// HTTP client used to send completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Limiter through which completion requests are issued.
    request_limiter: RateLimiter,
}
270
271impl OpenAiCompatibleLanguageModel {
272 fn stream_completion(
273 &self,
274 request: open_ai::Request,
275 cx: &AsyncApp,
276 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
277 {
278 let http_client = self.http_client.clone();
279 let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
280 (state.api_key.clone(), state.settings.api_url.clone())
281 }) else {
282 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
283 };
284
285 let provider = self.provider_name.clone();
286 let future = self.request_limiter.stream(async move {
287 let Some(api_key) = api_key else {
288 return Err(LanguageModelCompletionError::NoApiKey { provider });
289 };
290 let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
291 let response = request.await?;
292 Ok(response)
293 });
294
295 async move { Ok(future.await?.boxed()) }.boxed()
296 }
297}
298
299impl LanguageModel for OpenAiCompatibleLanguageModel {
300 fn id(&self) -> LanguageModelId {
301 self.id.clone()
302 }
303
304 fn name(&self) -> LanguageModelName {
305 LanguageModelName::from(
306 self.model
307 .display_name
308 .clone()
309 .unwrap_or_else(|| self.model.name.clone()),
310 )
311 }
312
313 fn provider_id(&self) -> LanguageModelProviderId {
314 self.provider_id.clone()
315 }
316
317 fn provider_name(&self) -> LanguageModelProviderName {
318 self.provider_name.clone()
319 }
320
321 fn supports_tools(&self) -> bool {
322 self.model.capabilities.tools
323 }
324
325 fn supports_images(&self) -> bool {
326 self.model.capabilities.images
327 }
328
329 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
330 match choice {
331 LanguageModelToolChoice::Auto => self.model.capabilities.tools,
332 LanguageModelToolChoice::Any => self.model.capabilities.tools,
333 LanguageModelToolChoice::None => true,
334 }
335 }
336
337 fn telemetry_id(&self) -> String {
338 format!("openai/{}", self.model.name)
339 }
340
341 fn max_token_count(&self) -> u64 {
342 self.model.max_tokens
343 }
344
345 fn max_output_tokens(&self) -> Option<u64> {
346 self.model.max_output_tokens
347 }
348
349 fn count_tokens(
350 &self,
351 request: LanguageModelRequest,
352 cx: &App,
353 ) -> BoxFuture<'static, Result<u64>> {
354 let max_token_count = self.max_token_count();
355 cx.background_spawn(async move {
356 let messages = super::open_ai::collect_tiktoken_messages(request);
357 let model = if max_token_count >= 100_000 {
358 // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
359 "gpt-4o"
360 } else {
361 // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
362 // supported with this tiktoken method
363 "gpt-4"
364 };
365 tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
366 })
367 .boxed()
368 }
369
370 fn stream_completion(
371 &self,
372 request: LanguageModelRequest,
373 cx: &AsyncApp,
374 ) -> BoxFuture<
375 'static,
376 Result<
377 futures::stream::BoxStream<
378 'static,
379 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
380 >,
381 LanguageModelCompletionError,
382 >,
383 > {
384 let request = into_open_ai(
385 request,
386 &self.model.name,
387 self.model.capabilities.parallel_tool_calls,
388 self.model.capabilities.prompt_cache_key,
389 self.max_output_tokens(),
390 None,
391 );
392 let completions = self.stream_completion(request, cx);
393 async move {
394 let mapper = OpenAiEventMapper::new();
395 Ok(mapper.map_stream(completions.await?).boxed())
396 }
397 .boxed()
398 }
399}
400
/// View for entering or resetting the provider's API key.
struct ConfigurationView {
    /// Single-line input the API key is typed into.
    api_key_editor: Entity<SingleLineInput>,
    /// Shared provider state this view reads from and updates.
    state: gpui::Entity<State>,
    /// Task that loads existing credentials when the view is created; set back to `None` once
    /// it finishes. While it is `Some`, the view renders a loading label.
    load_credentials_task: Option<Task<()>>,
}
406
407impl ConfigurationView {
408 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
409 let api_key_editor = cx.new(|cx| {
410 SingleLineInput::new(
411 window,
412 cx,
413 "000000000000000000000000000000000000000000000000000",
414 )
415 });
416
417 cx.observe(&state, |_, _, cx| {
418 cx.notify();
419 })
420 .detach();
421
422 let load_credentials_task = Some(cx.spawn_in(window, {
423 let state = state.clone();
424 async move |this, cx| {
425 if let Some(task) = state
426 .update(cx, |state, cx| state.authenticate(cx))
427 .log_err()
428 {
429 // We don't log an error, because "not signed in" is also an error.
430 let _ = task.await;
431 }
432 this.update(cx, |this, cx| {
433 this.load_credentials_task = None;
434 cx.notify();
435 })
436 .log_err();
437 }
438 }));
439
440 Self {
441 api_key_editor,
442 state,
443 load_credentials_task,
444 }
445 }
446
447 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
448 let api_key = self
449 .api_key_editor
450 .read(cx)
451 .editor()
452 .read(cx)
453 .text(cx)
454 .trim()
455 .to_string();
456
457 // Don't proceed if no API key is provided and we're not authenticated
458 if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
459 return;
460 }
461
462 let state = self.state.clone();
463 cx.spawn_in(window, async move |_, cx| {
464 state
465 .update(cx, |state, cx| state.set_api_key(api_key, cx))?
466 .await
467 })
468 .detach_and_log_err(cx);
469
470 cx.notify();
471 }
472
473 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
474 self.api_key_editor.update(cx, |input, cx| {
475 input.editor.update(cx, |editor, cx| {
476 editor.set_text("", window, cx);
477 });
478 });
479
480 let state = self.state.clone();
481 cx.spawn_in(window, async move |_, cx| {
482 state.update(cx, |state, cx| state.reset_api_key(cx))?.await
483 })
484 .detach_and_log_err(cx);
485
486 cx.notify();
487 }
488
489 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
490 !self.state.read(cx).is_authenticated()
491 }
492}
493
494impl Render for ConfigurationView {
495 fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
496 let env_var_set = self.state.read(cx).api_key_from_env;
497 let env_var_name = self.state.read(cx).env_var_name.clone();
498
499 let api_key_section = if self.should_render_editor(cx) {
500 v_flex()
501 .on_action(cx.listener(Self::save_api_key))
502 .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
503 .child(
504 div()
505 .pt(DynamicSpacing::Base04.rems(cx))
506 .child(self.api_key_editor.clone())
507 )
508 .child(
509 Label::new(
510 format!("You can also assign the {env_var_name} environment variable and restart Zed."),
511 )
512 .size(LabelSize::Small).color(Color::Muted),
513 )
514 .into_any()
515 } else {
516 h_flex()
517 .mt_1()
518 .p_1()
519 .justify_between()
520 .rounded_md()
521 .border_1()
522 .border_color(cx.theme().colors().border)
523 .bg(cx.theme().colors().background)
524 .child(
525 h_flex()
526 .gap_1()
527 .child(Icon::new(IconName::Check).color(Color::Success))
528 .child(Label::new(if env_var_set {
529 format!("API key set in {env_var_name} environment variable.")
530 } else {
531 "API key configured.".to_string()
532 })),
533 )
534 .child(
535 Button::new("reset-api-key", "Reset API Key")
536 .label_size(LabelSize::Small)
537 .icon(IconName::Undo)
538 .icon_size(IconSize::Small)
539 .icon_position(IconPosition::Start)
540 .layer(ElevationIndex::ModalSurface)
541 .when(env_var_set, |this| {
542 this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
543 })
544 .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
545 )
546 .into_any()
547 };
548
549 if self.load_credentials_task.is_some() {
550 div().child(Label::new("Loading credentials…")).into_any()
551 } else {
552 v_flex().size_full().child(api_key_section).into_any()
553 }
554 }
555}