open_ai_compatible.rs

  1use anyhow::{Result, anyhow};
  2use convert_case::{Case, Casing};
  3use futures::{FutureExt, StreamExt, future, future::BoxFuture};
  4use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  5use http_client::HttpClient;
  6use language_model::{
  7    ApiKeyState, AuthenticateError, EnvVar, LanguageModel, LanguageModelCompletionError,
  8    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  9    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 10    LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolSchemaFormat, RateLimiter,
 11};
 12use menu;
 13use open_ai::{ResponseStreamEvent, stream_completion};
 14use settings::{Settings, SettingsStore};
 15use std::sync::Arc;
 16use ui::{ElevationIndex, Tooltip, prelude::*};
 17use ui_input::InputField;
 18use util::ResultExt;
 19
 20use crate::provider::open_ai::{OpenAiEventMapper, into_open_ai};
 21pub use settings::OpenAiCompatibleAvailableModel as AvailableModel;
 22pub use settings::OpenAiCompatibleModelCapabilities as ModelCapabilities;
 23
 24#[derive(Default, Clone, Debug, PartialEq)]
 25pub struct OpenAiCompatibleSettings {
 26    pub api_url: String,
 27    pub available_models: Vec<AvailableModel>,
 28}
 29
 30pub struct OpenAiCompatibleLanguageModelProvider {
 31    id: LanguageModelProviderId,
 32    name: LanguageModelProviderName,
 33    http_client: Arc<dyn HttpClient>,
 34    state: Entity<State>,
 35}
 36
 37pub struct State {
 38    id: Arc<str>,
 39    api_key_state: ApiKeyState,
 40    settings: OpenAiCompatibleSettings,
 41}
 42
 43impl State {
 44    fn is_authenticated(&self) -> bool {
 45        self.api_key_state.has_key()
 46    }
 47
 48    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 49        let api_url = SharedString::new(self.settings.api_url.as_str());
 50        self.api_key_state
 51            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
 52    }
 53
 54    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 55        let api_url = SharedString::new(self.settings.api_url.clone());
 56        self.api_key_state
 57            .load_if_needed(api_url, |this| &mut this.api_key_state, cx)
 58    }
 59}
 60
 61impl OpenAiCompatibleLanguageModelProvider {
 62    pub fn new(id: Arc<str>, http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 63        fn resolve_settings<'a>(id: &'a str, cx: &'a App) -> Option<&'a OpenAiCompatibleSettings> {
 64            crate::AllLanguageModelSettings::get_global(cx)
 65                .openai_compatible
 66                .get(id)
 67        }
 68
 69        let api_key_env_var_name = format!("{}_API_KEY", id).to_case(Case::UpperSnake).into();
 70        let state = cx.new(|cx| {
 71            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 72                let Some(settings) = resolve_settings(&this.id, cx).cloned() else {
 73                    return;
 74                };
 75                if &this.settings != &settings {
 76                    let api_url = SharedString::new(settings.api_url.as_str());
 77                    this.api_key_state.handle_url_change(
 78                        api_url,
 79                        |this| &mut this.api_key_state,
 80                        cx,
 81                    );
 82                    this.settings = settings;
 83                    cx.notify();
 84                }
 85            })
 86            .detach();
 87            let settings = resolve_settings(&id, cx).cloned().unwrap_or_default();
 88            State {
 89                id: id.clone(),
 90                api_key_state: ApiKeyState::new(
 91                    SharedString::new(settings.api_url.as_str()),
 92                    EnvVar::new(api_key_env_var_name),
 93                ),
 94                settings,
 95            }
 96        });
 97
 98        Self {
 99            id: id.clone().into(),
100            name: id.into(),
101            http_client,
102            state,
103        }
104    }
105
106    fn create_language_model(&self, model: AvailableModel) -> Arc<dyn LanguageModel> {
107        Arc::new(OpenAiCompatibleLanguageModel {
108            id: LanguageModelId::from(model.name.clone()),
109            provider_id: self.id.clone(),
110            provider_name: self.name.clone(),
111            model,
112            state: self.state.clone(),
113            http_client: self.http_client.clone(),
114            request_limiter: RateLimiter::new(4),
115        })
116    }
117}
118
119impl LanguageModelProviderState for OpenAiCompatibleLanguageModelProvider {
120    type ObservableEntity = State;
121
122    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
123        Some(self.state.clone())
124    }
125}
126
127impl LanguageModelProvider for OpenAiCompatibleLanguageModelProvider {
128    fn id(&self) -> LanguageModelProviderId {
129        self.id.clone()
130    }
131
132    fn name(&self) -> LanguageModelProviderName {
133        self.name.clone()
134    }
135
136    fn icon(&self) -> IconName {
137        IconName::AiOpenAiCompat
138    }
139
140    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
141        self.state
142            .read(cx)
143            .settings
144            .available_models
145            .first()
146            .map(|model| self.create_language_model(model.clone()))
147    }
148
149    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
150        None
151    }
152
153    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
154        self.state
155            .read(cx)
156            .settings
157            .available_models
158            .iter()
159            .map(|model| self.create_language_model(model.clone()))
160            .collect()
161    }
162
163    fn is_authenticated(&self, cx: &App) -> bool {
164        self.state.read(cx).is_authenticated()
165    }
166
167    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
168        self.state.update(cx, |state, cx| state.authenticate(cx))
169    }
170
171    fn configuration_view(
172        &self,
173        _target_agent: language_model::ConfigurationViewTargetAgent,
174        window: &mut Window,
175        cx: &mut App,
176    ) -> AnyView {
177        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
178            .into()
179    }
180
181    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
182        self.state
183            .update(cx, |state, cx| state.set_api_key(None, cx))
184    }
185}
186
187pub struct OpenAiCompatibleLanguageModel {
188    id: LanguageModelId,
189    provider_id: LanguageModelProviderId,
190    provider_name: LanguageModelProviderName,
191    model: AvailableModel,
192    state: Entity<State>,
193    http_client: Arc<dyn HttpClient>,
194    request_limiter: RateLimiter,
195}
196
197impl OpenAiCompatibleLanguageModel {
198    fn stream_completion(
199        &self,
200        request: open_ai::Request,
201        cx: &AsyncApp,
202    ) -> BoxFuture<
203        'static,
204        Result<
205            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
206            LanguageModelCompletionError,
207        >,
208    > {
209        let http_client = self.http_client.clone();
210
211        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| {
212            let api_url = &state.settings.api_url;
213            (
214                state.api_key_state.key(api_url),
215                state.settings.api_url.clone(),
216            )
217        }) else {
218            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
219        };
220
221        let provider = self.provider_name.clone();
222        let future = self.request_limiter.stream(async move {
223            let Some(api_key) = api_key else {
224                return Err(LanguageModelCompletionError::NoApiKey { provider });
225            };
226            let request = stream_completion(
227                http_client.as_ref(),
228                provider.0.as_str(),
229                &api_url,
230                &api_key,
231                request,
232            );
233            let response = request.await?;
234            Ok(response)
235        });
236
237        async move { Ok(future.await?.boxed()) }.boxed()
238    }
239}
240
241impl LanguageModel for OpenAiCompatibleLanguageModel {
242    fn id(&self) -> LanguageModelId {
243        self.id.clone()
244    }
245
246    fn name(&self) -> LanguageModelName {
247        LanguageModelName::from(
248            self.model
249                .display_name
250                .clone()
251                .unwrap_or_else(|| self.model.name.clone()),
252        )
253    }
254
255    fn provider_id(&self) -> LanguageModelProviderId {
256        self.provider_id.clone()
257    }
258
259    fn provider_name(&self) -> LanguageModelProviderName {
260        self.provider_name.clone()
261    }
262
263    fn supports_tools(&self) -> bool {
264        self.model.capabilities.tools
265    }
266
267    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
268        LanguageModelToolSchemaFormat::JsonSchemaSubset
269    }
270
271    fn supports_images(&self) -> bool {
272        self.model.capabilities.images
273    }
274
275    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
276        match choice {
277            LanguageModelToolChoice::Auto => self.model.capabilities.tools,
278            LanguageModelToolChoice::Any => self.model.capabilities.tools,
279            LanguageModelToolChoice::None => true,
280        }
281    }
282
283    fn telemetry_id(&self) -> String {
284        format!("openai/{}", self.model.name)
285    }
286
287    fn max_token_count(&self) -> u64 {
288        self.model.max_tokens
289    }
290
291    fn max_output_tokens(&self) -> Option<u64> {
292        self.model.max_output_tokens
293    }
294
295    fn count_tokens(
296        &self,
297        request: LanguageModelRequest,
298        cx: &App,
299    ) -> BoxFuture<'static, Result<u64>> {
300        let max_token_count = self.max_token_count();
301        cx.background_spawn(async move {
302            let messages = super::open_ai::collect_tiktoken_messages(request);
303            let model = if max_token_count >= 100_000 {
304                // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
305                "gpt-4o"
306            } else {
307                // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
308                // supported with this tiktoken method
309                "gpt-4"
310            };
311            tiktoken_rs::num_tokens_from_messages(model, &messages).map(|tokens| tokens as u64)
312        })
313        .boxed()
314    }
315
316    fn stream_completion(
317        &self,
318        request: LanguageModelRequest,
319        cx: &AsyncApp,
320    ) -> BoxFuture<
321        'static,
322        Result<
323            futures::stream::BoxStream<
324                'static,
325                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
326            >,
327            LanguageModelCompletionError,
328        >,
329    > {
330        let request = into_open_ai(
331            request,
332            &self.model.name,
333            self.model.capabilities.parallel_tool_calls,
334            self.model.capabilities.prompt_cache_key,
335            self.max_output_tokens(),
336            None,
337        );
338        let completions = self.stream_completion(request, cx);
339        async move {
340            let mapper = OpenAiEventMapper::new();
341            Ok(mapper.map_stream(completions.await?).boxed())
342        }
343        .boxed()
344    }
345}
346
347struct ConfigurationView {
348    api_key_editor: Entity<InputField>,
349    state: Entity<State>,
350    load_credentials_task: Option<Task<()>>,
351}
352
353impl ConfigurationView {
354    fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
355        let api_key_editor = cx.new(|cx| {
356            InputField::new(
357                window,
358                cx,
359                "000000000000000000000000000000000000000000000000000",
360            )
361        });
362
363        cx.observe(&state, |_, _, cx| {
364            cx.notify();
365        })
366        .detach();
367
368        let load_credentials_task = Some(cx.spawn_in(window, {
369            let state = state.clone();
370            async move |this, cx| {
371                if let Some(task) = state
372                    .update(cx, |state, cx| state.authenticate(cx))
373                    .log_err()
374                {
375                    // We don't log an error, because "not signed in" is also an error.
376                    let _ = task.await;
377                }
378                this.update(cx, |this, cx| {
379                    this.load_credentials_task = None;
380                    cx.notify();
381                })
382                .log_err();
383            }
384        }));
385
386        Self {
387            api_key_editor,
388            state,
389            load_credentials_task,
390        }
391    }
392
393    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
394        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
395        if api_key.is_empty() {
396            return;
397        }
398
399        // url changes can cause the editor to be displayed again
400        self.api_key_editor
401            .update(cx, |input, cx| input.set_text("", window, cx));
402
403        let state = self.state.clone();
404        cx.spawn_in(window, async move |_, cx| {
405            state
406                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
407                .await
408        })
409        .detach_and_log_err(cx);
410    }
411
412    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
413        self.api_key_editor
414            .update(cx, |input, cx| input.set_text("", window, cx));
415
416        let state = self.state.clone();
417        cx.spawn_in(window, async move |_, cx| {
418            state
419                .update(cx, |state, cx| state.set_api_key(None, cx))?
420                .await
421        })
422        .detach_and_log_err(cx);
423    }
424
425    fn should_render_editor(&self, cx: &Context<Self>) -> bool {
426        !self.state.read(cx).is_authenticated()
427    }
428}
429
430impl Render for ConfigurationView {
431    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
432        let state = self.state.read(cx);
433        let env_var_set = state.api_key_state.is_from_env_var();
434        let env_var_name = state.api_key_state.env_var_name();
435
436        let api_key_section = if self.should_render_editor(cx) {
437            v_flex()
438                .on_action(cx.listener(Self::save_api_key))
439                .child(Label::new("To use Zed's agent with an OpenAI-compatible provider, you need to add an API key."))
440                .child(
441                    div()
442                        .pt(DynamicSpacing::Base04.rems(cx))
443                        .child(self.api_key_editor.clone())
444                )
445                .child(
446                    Label::new(
447                        format!("You can also assign the {env_var_name} environment variable and restart Zed."),
448                    )
449                    .size(LabelSize::Small).color(Color::Muted),
450                )
451                .into_any()
452        } else {
453            h_flex()
454                .mt_1()
455                .p_1()
456                .justify_between()
457                .rounded_md()
458                .border_1()
459                .border_color(cx.theme().colors().border)
460                .bg(cx.theme().colors().background)
461                .child(
462                    h_flex()
463                        .flex_1()
464                        .min_w_0()
465                        .gap_1()
466                        .child(Icon::new(IconName::Check).color(Color::Success))
467                        .child(
468                            div()
469                                .w_full()
470                                .overflow_x_hidden()
471                                .text_ellipsis()
472                                .child(Label::new(
473                                    if env_var_set {
474                                        format!("API key set in {env_var_name} environment variable")
475                                    } else {
476                                        format!("API key configured for {}", &state.settings.api_url)
477                                    }
478                                ))
479                        ),
480                )
481                .child(
482                    h_flex()
483                        .flex_shrink_0()
484                        .child(
485                            Button::new("reset-api-key", "Reset API Key")
486                                .label_size(LabelSize::Small)
487                                .icon(IconName::Undo)
488                                .icon_size(IconSize::Small)
489                                .icon_position(IconPosition::Start)
490                                .layer(ElevationIndex::ModalSurface)
491                                .when(env_var_set, |this| {
492                                    this.tooltip(Tooltip::text(format!("To reset your API key, unset the {env_var_name} environment variable.")))
493                                })
494                                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
495                        ),
496                )
497                .into_any()
498        };
499
500        if self.load_credentials_task.is_some() {
501            div().child(Label::new("Loading credentials…")).into_any()
502        } else {
503            v_flex().size_full().child(api_key_section).into_any()
504        }
505    }
506}