open_ai_compatible.rs

  1use anyhow::{Context as _, Result, anyhow};
  2use credentials_provider::CredentialsProvider;
  3
  4use convert_case::{Case, Casing};
  5use futures::{FutureExt, StreamExt, future::BoxFuture};
  6use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
  7use http_client::HttpClient;
  8use language_model::{
  9    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 10    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 11    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 12    LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
 13};
 14use menu;
 15use open_ai::{ResponseStreamEvent, stream_completion};
 16use settings::{Settings, SettingsStore};
 17use std::sync::Arc;
 18
 19use ui::{ElevationIndex, Tooltip, prelude::*};
 20use ui_input::SingleLineInput;
 21use util::ResultExt;
 22
 23use crate::AllLanguageModelSettings;
 24use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
 25pub use settings::OpenAiCompatibleAvailableModel as AvailableModel;
 26pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
 27
 28#[derive(Default, Clone, Debug, PartialEq)]
 29pub struct OpenAiCompatibleSettings {
 30    pub api_url: String,
 31    pub available_models: Vec<AvailableModel>,
 32}
 33
 34pub struct OpenAiCompatibleLanguageModelProvider {
 35    id: LanguageModelProviderId,
 36    name: LanguageModelProviderName,
 37    http_client: Arc<dyn HttpClient>,
 38    state: gpui::Entity<State>,
 39}
 40
 41pub struct State {
 42    id: Arc<str>,
 43    env_var_name: Arc<str>,
 44    api_key: Option<String>,
 45    api_key_from_env: bool,
 46    settings: OpenAiCompatibleSettings,
 47    _subscription: Subscription,
 48}
 49
 50impl State {
 51    fn is_authenticated(&self) -> bool {
 52        self.api_key.is_some()
 53    }
 54
 55    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 56        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 57        let api_url = self.settings.api_url.clone();
 58        cx.spawn(async move |this, cx| {
 59            credentials_provider
 60                .delete_credentials(&api_url, cx)
 61                .await
 62                .log_err();
 63            this.update(cx, |this, cx| {
 64                this.api_key = None;
 65                this.api_key_from_env = false;
 66                cx.notify();
 67            })
 68        })
 69    }
 70
 71    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 72        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 73        let api_url = self.settings.api_url.clone();
 74        cx.spawn(async move |this, cx| {
 75            credentials_provider
 76                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
 77                .await
 78                .log_err();
 79            this.update(cx, |this, cx| {
 80                this.api_key = Some(api_key);
 81                cx.notify();
 82            })
 83        })
 84    }
 85
 86    fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 87        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 88        let env_var_name = self.env_var_name.clone();
 89        let api_url = self.settings.api_url.clone();
 90        cx.spawn(async move |this, cx| {
 91            let (api_key, from_env) = if let Ok(api_key) = std::env::var(env_var_name.as_ref()) {
 92                (api_key, true)
 93            } else {
 94                let (_, api_key) = credentials_provider
 95                    .read_credentials(&api_url, cx)
 96                    .await?
 97                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 98                (
 99                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
100                    false,
101                )
102            };
103            this.update(cx, |this, cx| {
104                this.api_key = Some(api_key);
105                this.api_key_from_env = from_env;
106                cx.notify();
107            })?;
108
109            Ok(())
110        })
111    }
112
113    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
114        if self.is_authenticated() {
115            return Task::ready(Ok(()));
116        }
117
118        self.get_api_key(cx)
119    }
120}
121
122impl OpenAiCompatibleLanguageModelProvider {
123    pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
124        fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
125            AllLanguageModelSettings::get_global(cx)
126                .openai_compatible
127                .get(id)
128        }
129
130        let state = cx.new(|cx| State {
131            id: id.clone(),
132            env_var_name: format!("{}_API_KEY", id).to_case(Case::Constant).into(),
133            settings: resolve_settings(&id, cx).cloned().unwrap_or_default(),
134            api_key: None,
135            api_key_from_env: false,
136            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
137                let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
138                    return;
139                };
140                if &this.settings != &settings {
141                    if settings.api_url != this.settings.api_url && !this.api_key_from_env {
142                        let spawn_task = cx.spawn(async move |handle, cx| {
143                            if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
144                                if let Err(_) = task.await {
145                                    handle
146                                        .update(cx, |this, _| {
147                                            this.api_key = None;
148                                            this.api_key_from_env = false;
149                                        })
150                                        .ok();
151                                }
152                            }
153                        });
154                        spawn_task.detach();
155                    }
156
157                    this.settings = settings;
158                    cx.notify();
159                }
160            }),
161        });
162
163        Self {
164            id: id.clone().into(),
165            name: id.into(),
166            http_client,
167            state,
168        }
169    }
170
171    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
172        Arc::new(OpenAiCompatibleLanguageModel {
173            id: LanguageModelId::from(model.name.clone()),
174            provider_id: self.id.clone(),
175            provider_name: self.name.clone(),
176            model,
177            state: self.state.clone(),
178            http_client: self.http_client.clone(),
179            request_limiter: RateLimiter::new(4),
180        })
181    }
182}
183
184impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
185    type ObservableEntity = State;
186
187    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
188        Some(self.state.clone())
189    }
190}
191
192impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
193    fn id(&self) -> LanguageModelProviderId {
194        self.id.clone()
195    }
196
197    fn name(&self) -> LanguageModelProviderName {
198        self.name.clone()
199    }
200
201    fn icon(&self) -> IconName {
202        IconName::AiOpenAiCompat
203    }
204
205    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
206        self.state
207            .read(cx)
208            .settings
209            .available_models
210            .first()
211            .map(|model| self.create_language_model(model.clone()))
212    }
213
214    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
215        None
216    }
217
218    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
219        self.state
220            .read(cx)
221            .settings
222            .available_models
223            .iter()
224            .map(|model| self.create_language_model(model.clone()))
225            .collect()
226    }
227
228    fn is_authenticated(&self, cx: &App) -> bool {
229        self.state.read(cx).is_authenticated()
230    }
231
232    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
233        self.state.update(cx, |state, cx| state.authenticate(cx))
234    }
235
236    fn configuration_view(
237        &self,
238        _target_agent: language_model::ConfigurationViewTargetAgent,
239        window: &mut Window,
240        cx: &mut App,
241    ) -> AnyView {
242        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
243            .into()
244    }
245
246    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
247        self.state.update(cx, |state, cx| state.reset_api_key(cx))
248    }
249}
250
251pub struct OpenAiCompatibleLanguageModel {
252    id: LanguageModelId,
253    provider_id: LanguageModelProviderId,
254    provider_name: LanguageModelProviderName,
255    model: AvailableModel,
256    state: gpui::Entity<State>,
257    http_client: Arc<dyn HttpClient>,
258    request_limiter: RateLimiter,
259}
260
261impl OpenAiCompatibleLanguageModel {
262    fn stream_completion(
263        &self,
264        request: open_ai::Request,
265        cx: &AsyncApp,
266    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
267    {
268        let http_client = self.http_client.clone();
269        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, _| {
270            (state.api_key.clone(), state.settings.api_url.clone())
271        }) else {
272            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
273        };
274
275        let provider = self.provider_name.clone();
276        let future = self.request_limiter.stream(async move {
277            let Some(api_key) = api_key else {
278                return Err(LanguageModelCompletionError::NoApiKey { provider });
279            };
280            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
281            let response = request.await?;
282            Ok(response)
283        });
284
285        async move { Ok(future.await?.boxed()) }.boxed()
286    }
287}
288
289impl LanguageModel for OpenAiCompatibleLanguageModel {
290    fn id(&self) -> LanguageModelId {
291        self.id.clone()
292    }
293
294    fn name(&self) -> LanguageModelName {
295        LanguageModelName::from(
296            self.model
297                .display_name
298                .clone()
299                .unwrap_or_else(|| self.model.name.clone()),
300        )
301    }
302
303    fn provider_id(&self) -> LanguageModelProviderId {
304        self.provider_id.clone()
305    }
306
307    fn provider_name(&self) -> LanguageModelProviderName {
308        self.provider_name.clone()
309    }
310
311    fn supports_tools(&self) -> bool {
312        self.model.capabilities.tools
313    }
314
315    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
316        LanguageModelToolSchemaFormat::JsonSchemaSubset
317    }
318
319    fn supports_images(&self) -> bool {
320        self.model.capabilities.images
321    }
322
323    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
324        match choice {
325            LanguageModelToolChoice::Auto => self.model.capabilities.tools,
326            LanguageModelToolChoice::Any => self.model.capabilities.tools,
327            LanguageModelToolChoice::None => true,
328        }
329    }
330
331    fn telemetry_id(&self) -> String {
332        format!("openai/{}", self.model.name)
333    }
334
335    fn max_token_count(&self) -> u64 {
336        self.model.max_tokens
337    }
338
339    fn max_output_tokens(&self) -> Option<u64> {
340        self.model.max_output_tokens
341    }
342
343    fn count_tokens(
344        &self,
345        request: LanguageModelRequest,
346        cx: &App,
347    ) -> BoxFuture<'static, Result<u64>> {
348        let max_token_count = self.max_token_count();
349        cx.background_spawn(async move {
350            let messages = super::open_ai::collect_tiktoken_messages(request);
351            let model = if max_token_count >= 100_000 {
352                // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
353                "gpt-4o"
354            } else {
355                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
356                // supported with this tiktoken method
357                "gpt-4"
358            };
359            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
360        })
361        .boxed()
362    }
363
364    fn stream_completion(
365        &self,
366        request: LanguageModelRequest,
367        cx: &AsyncApp,
368    ) -> BoxFuture<
369        'static,
370        Result<
371            futures::stream::BoxStream<
372                'static,
373                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
374            >,
375            LanguageModelCompletionError,
376        >,
377    > {
378        let request = into_open_ai(
379            request,
380            &self.model.name,
381            self.model.capabilities.parallel_tool_calls,
382            self.model.capabilities.prompt_cache_key,
383            self.max_output_tokens(),
384            None,
385        );
386        let completions = self.stream_completion(request, cx);
387        async move {
388            let mapper = OpenAiEventMapper::new();
389            Ok(mapper.map_stream(completions.await?).boxed())
390        }
391        .boxed()
392    }
393}
394
395struct ConfigurationView {
396    api_key_editor: Entity<SingleLineInput>,
397    state: gpui::Entity<State>,
398    load_credentials_task: Option<Task<()>>,
399}
400
401impl ConfigurationView {
402    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
403        let api_key_editor = cx.new(|cx| {
404            SingleLineInput::new(
405                window,
406                cx,
407                "000000000000000000000000000000000000000000000000000",
408            )
409        });
410
411        cx.observe(&state, |_, _, cx| {
412            cx.notify();
413        })
414        .detach();
415
416        let load_credentials_task = Some(cx.spawn_in(window, {
417            let state = state.clone();
418            async move |this, cx| {
419                if let Some(task) = state
420                    .update(cx, |state, cx| state.authenticate(cx))
421                    .log_err()
422                {
423                    // We don't log an error, because "not signed in" is also an error.
424                    let _ = task.await;
425                }
426                this.update(cx, |this, cx| {
427                    this.load_credentials_task = None;
428                    cx.notify();
429                })
430                .log_err();
431            }
432        }));
433
434        Self {
435            api_key_editor,
436            state,
437            load_credentials_task,
438        }
439    }
440
441    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
442        let api_key = self
443            .api_key_editor
444            .read(cx)
445            .editor()
446            .read(cx)
447            .text(cx)
448            .trim()
449            .to_string();
450
451        // Don't proceed if no API key is provided and we're not authenticated
452        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
453            return;
454        }
455
456        let state = self.state.clone();
457        cx.spawn_in(window, async move |_, cx| {
458            state
459                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
460                .await
461        })
462        .detach_and_log_err(cx);
463
464        cx.notify();
465    }
466
467    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
468        self.api_key_editor.update(cx, |input, cx| {
469            input.editor.update(cx, |editor, cx| {
470                editor.set_text("", window, cx);
471            });
472        });
473
474        let state = self.state.clone();
475        cx.spawn_in(window, async move |_, cx| {
476            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
477        })
478        .detach_and_log_err(cx);
479
480        cx.notify();
481    }
482
483    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
484        !self.state.read(cx).is_authenticated()
485    }
486}
487
488impl Render for ConfigurationView {
489    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
490        let env_var_set = self.state.read(cx).api_key_from_env;
491        let env_var_name = self.state.read(cx).env_var_name.clone();
492
493        let api_key_section = if self.should_render_editor(cx) {
494            v_flex()
495                .on_action(cx.listener(Self::save_api_key))
496                .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
497                .child(
498                    div()
499                        .pt(DynamicSpacing::Base04.rems(cx))
500                        .child(self.api_key_editor.clone())
501                )
502                .child(
503                    Label::new(
504                        format!("You can also assign the {env_var_name} environment variable and restart Zed."),
505                    )
506                    .size(LabelSize::Small).color(Color::Muted),
507                )
508                .into_any()
509        } else {
510            h_flex()
511                .mt_1()
512                .p_1()
513                .justify_between()
514                .rounded_md()
515                .border_1()
516                .border_color(cx.theme().colors().border)
517                .bg(cx.theme().colors().background)
518                .child(
519                    h_flex()
520                        .gap_1()
521                        .child(Icon::new(IconName::Check).color(Color::Success))
522                        .child(Label::new(if env_var_set {
523                            format!("API key set in {env_var_name} environment variable.")
524                        } else {
525                            "API key configured.".to_string()
526                        })),
527                )
528                .child(
529                    Button::new("reset-api-key", "Reset API Key")
530                        .label_size(LabelSize::Small)
531                        .icon(IconName::Undo)
532                        .icon_size(IconSize::Small)
533                        .icon_position(IconPosition::Start)
534                        .layer(ElevationIndex::ModalSurface)
535                        .when(env_var_set, |this| {
536                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
537                        })
538                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
539                )
540                .into_any()
541        };
542
543        if self.load_credentials_task.is_some() {
544            div().child(Label::new("Loading credentials…")).into_any()
545        } else {
546            v_flex().size_full().child(api_key_section).into_any()
547        }
548    }
549}