open_router.rs

  1use anyhow::{Result, anyhow};
  2use collections::HashMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
  5use gpui::{
  6    AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
  7};
  8use http_client::HttpClient;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 11    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 12    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 13    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 14    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 15};
 16use open_router::{
 17    Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models,
 18};
 19use schemars::JsonSchema;
 20use serde::{Deserialize, Serialize};
 21use settings::{Settings, SettingsStore};
 22use std::pin::Pin;
 23use std::str::FromStr as _;
 24use std::sync::{Arc, LazyLock};
 25use theme::ThemeSettings;
 26use ui::{Icon, IconName, List, Tooltip, prelude::*};
 27use util::ResultExt;
 28use zed_env_vars::{EnvVar, env_var};
 29
 30use crate::{AllLanguageModelSettings, api_key::ApiKeyState, ui::InstructionListItem};
 31
/// Identifier returned from `id()` / `provider_id()` for this provider and its models.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
/// Name returned from `name()` / `provider_name()` and reported in `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");

/// Environment variable that may supply the OpenRouter API key.
const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
/// Handle for [`API_KEY_ENV_VAR_NAME`], passed to `ApiKeyState::load_if_needed`.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 37
/// OpenRouter section of the language-model settings.
///
/// `State` keeps a copy and compares it against the global settings on every
/// `SettingsStore` change to decide whether to re-authenticate.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
    /// Base URL of the OpenRouter API; also the key under which the API key is stored.
    pub api_url: String,
    /// Models declared in settings; these replace API-listed models with the same name.
    pub available_models: Vec<AvailableModel>,
}
 43
// A model declared in the user's settings. `provided_models` converts it into an
// `open_router::Model`. Plain `//` comments are used here on purpose: this type
// derives `JsonSchema`, and `///` doc comments would become schema descriptions.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model name; matched against API-listed model names when merging.
    pub name: String,
    // Copied into `open_router::Model::display_name`.
    pub display_name: Option<String>,
    // Copied into `open_router::Model::max_tokens`.
    pub max_tokens: u64,
    // NOTE(review): `max_output_tokens` and `max_completion_tokens` are not
    // forwarded by `provided_models`, so they currently have no effect.
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    // Copied into `open_router::Model::supports_tools`.
    pub supports_tools: Option<bool>,
    // Copied into `open_router::Model::supports_images`.
    pub supports_images: Option<bool>,
    // `None` is treated as `ModelMode::Default`.
    pub mode: Option<ModelMode>,
    // Provider routing preferences, forwarded as `provider` in each request.
    pub provider: Option<Provider>,
}
 56
// Reasoning mode for a settings-declared model, serialized with a `type` tag
// ("default" / "thinking"). Converted to and from `open_router::ModelMode`.
// `//` comments are used because this type derives `JsonSchema`.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    // When the request allows thinking, `into_open_router` sends a `reasoning`
    // block whose `max_tokens` is `budget_tokens`.
    Thinking {
        budget_tokens: Option<u32>,
    },
}
 66
 67impl From<ModelMode> for OpenRouterModelMode {
 68    fn from(value: ModelMode) -> Self {
 69        match value {
 70            ModelMode::Default => OpenRouterModelMode::Default,
 71            ModelMode::Thinking { budget_tokens } => {
 72                OpenRouterModelMode::Thinking { budget_tokens }
 73            }
 74        }
 75    }
 76}
 77
 78impl From<OpenRouterModelMode> for ModelMode {
 79    fn from(value: OpenRouterModelMode) -> Self {
 80        match value {
 81            OpenRouterModelMode::Default => ModelMode::Default,
 82            OpenRouterModelMode::Thinking { budget_tokens } => {
 83                ModelMode::Thinking { budget_tokens }
 84            }
 85        }
 86    }
 87}
 88
/// Language-model provider backed by the OpenRouter API.
pub struct OpenRouterLanguageModelProvider {
    /// HTTP client handed to every model created by `create_language_model`.
    http_client: Arc<dyn HttpClient>,
    /// Shared, observable authentication and model-list state.
    state: gpui::Entity<State>,
}
 93
/// Observable provider state: the API key, the fetched model list, and the
/// last-applied OpenRouter settings.
pub struct State {
    /// API key storage keyed by API URL; the key may come from `OPENROUTER_API_KEY`
    /// (see `is_from_env_var`).
    api_key_state: ApiKeyState,
    /// Client used by `fetch_models` to call `list_models`.
    http_client: Arc<dyn HttpClient>,
    /// Models returned by `list_models`; cleared when no key is available.
    available_models: Vec<open_router::Model>,
    /// In-flight model fetch; replacing the task drops the previous one.
    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
    /// Copy of the settings last applied, used to detect settings changes.
    settings: OpenRouterSettings,
}
101
102impl State {
103    fn is_authenticated(&self) -> bool {
104        self.api_key_state.has_key()
105    }
106
107    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
108        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
109        self.api_key_state
110            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
111    }
112
113    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
114        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
115        let task = self.api_key_state.load_if_needed(
116            api_url,
117            &API_KEY_ENV_VAR,
118            |this| &mut this.api_key_state,
119            cx,
120        );
121
122        cx.spawn(async move |this, cx| {
123            let result = task.await;
124            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
125                .ok();
126            result
127        })
128    }
129
130    fn fetch_models(
131        &mut self,
132        cx: &mut Context<Self>,
133    ) -> Task<Result<(), LanguageModelCompletionError>> {
134        let http_client = self.http_client.clone();
135        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
136        let Some(api_key) = self.api_key_state.key(&api_url) else {
137            return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
138                provider: PROVIDER_NAME,
139            }));
140        };
141        cx.spawn(async move |this, cx| {
142            let models = list_models(http_client.as_ref(), &api_url, &api_key)
143                .await
144                .map_err(|e| {
145                    LanguageModelCompletionError::Other(anyhow::anyhow!(
146                        "OpenRouter error: {:?}",
147                        e
148                    ))
149                })?;
150
151            this.update(cx, |this, cx| {
152                this.available_models = models;
153                cx.notify();
154            })
155            .map_err(|e| LanguageModelCompletionError::Other(e))?;
156
157            Ok(())
158        })
159    }
160
161    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
162        if self.is_authenticated() {
163            let task = self.fetch_models(cx);
164            self.fetch_models_task.replace(task);
165        } else {
166            self.available_models = Vec::new();
167        }
168    }
169}
170
171impl OpenRouterLanguageModelProvider {
172    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
173        let state = cx.new(|cx| {
174            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
175                let current_settings = &AllLanguageModelSettings::get_global(cx).open_router;
176                let settings_changed = current_settings != &this.settings;
177                if settings_changed {
178                    this.settings = current_settings.clone();
179                    this.authenticate(cx).detach();
180                }
181                cx.notify();
182            })
183            .detach();
184            State {
185                api_key_state: ApiKeyState::new(Self::api_url(cx)),
186                http_client: http_client.clone(),
187                available_models: Vec::new(),
188                fetch_models_task: None,
189                settings: OpenRouterSettings::default(),
190            }
191        });
192
193        Self { http_client, state }
194    }
195
196    fn settings(cx: &App) -> &OpenRouterSettings {
197        &AllLanguageModelSettings::get_global(cx).open_router
198    }
199
200    fn api_url(cx: &App) -> SharedString {
201        SharedString::new(Self::settings(cx).api_url.as_str())
202    }
203
204    fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
205        Arc::new(OpenRouterLanguageModel {
206            id: LanguageModelId::from(model.id().to_string()),
207            model,
208            state: self.state.clone(),
209            http_client: self.http_client.clone(),
210            request_limiter: RateLimiter::new(4),
211        })
212    }
213}
214
215impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
216    type ObservableEntity = State;
217
218    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
219        Some(self.state.clone())
220    }
221}
222
223impl LanguageModelProvider for OpenRouterLanguageModelProvider {
224    fn id(&self) -> LanguageModelProviderId {
225        PROVIDER_ID
226    }
227
228    fn name(&self) -> LanguageModelProviderName {
229        PROVIDER_NAME
230    }
231
232    fn icon(&self) -> IconName {
233        IconName::AiOpenRouter
234    }
235
236    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
237        Some(self.create_language_model(open_router::Model::default()))
238    }
239
240    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
241        Some(self.create_language_model(open_router::Model::default_fast()))
242    }
243
244    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
245        let mut models_from_api = self.state.read(cx).available_models.clone();
246        let mut settings_models = Vec::new();
247
248        for model in &Self::settings(cx).available_models {
249            settings_models.push(open_router::Model {
250                name: model.name.clone(),
251                display_name: model.display_name.clone(),
252                max_tokens: model.max_tokens,
253                supports_tools: model.supports_tools,
254                supports_images: model.supports_images,
255                mode: model.mode.clone().unwrap_or_default().into(),
256                provider: model.provider.clone(),
257            });
258        }
259
260        for settings_model in &settings_models {
261            if let Some(pos) = models_from_api
262                .iter()
263                .position(|m| m.name == settings_model.name)
264            {
265                models_from_api[pos] = settings_model.clone();
266            } else {
267                models_from_api.push(settings_model.clone());
268            }
269        }
270
271        models_from_api
272            .into_iter()
273            .map(|model| self.create_language_model(model))
274            .collect()
275    }
276
277    fn is_authenticated(&self, cx: &App) -> bool {
278        self.state.read(cx).is_authenticated()
279    }
280
281    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
282        self.state.update(cx, |state, cx| state.authenticate(cx))
283    }
284
285    fn configuration_view(
286        &self,
287        _target_agent: language_model::ConfigurationViewTargetAgent,
288        window: &mut Window,
289        cx: &mut App,
290    ) -> AnyView {
291        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
292            .into()
293    }
294
295    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
296        self.state
297            .update(cx, |state, cx| state.set_api_key(None, cx))
298    }
299}
300
/// A single OpenRouter model, implementing [`LanguageModel`].
pub struct OpenRouterLanguageModel {
    /// Identifier derived from the model's id string.
    id: LanguageModelId,
    /// Model metadata (name, limits, capabilities, mode, provider routing).
    model: open_router::Model,
    /// Shared provider state, read to obtain the API key when streaming.
    state: gpui::Entity<State>,
    /// Client used for completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Wraps each completion stream; see `RateLimiter`.
    request_limiter: RateLimiter,
}
308
309impl OpenRouterLanguageModel {
310    fn stream_completion(
311        &self,
312        request: open_router::Request,
313        cx: &AsyncApp,
314    ) -> BoxFuture<
315        'static,
316        Result<
317            futures::stream::BoxStream<
318                'static,
319                Result<ResponseStreamEvent, open_router::OpenRouterError>,
320            >,
321            LanguageModelCompletionError,
322        >,
323    > {
324        let http_client = self.http_client.clone();
325        let api_key_and_url = self.state.read_with(cx, |state, cx| {
326            let api_url = OpenRouterLanguageModelProvider::api_url(cx);
327            let api_key = state.api_key_state.key(&api_url);
328            (api_key, api_url)
329        });
330        let (api_key, api_url) = match api_key_and_url {
331            Ok(api_key_and_url) => api_key_and_url,
332            Err(err) => {
333                return futures::future::ready(Err(LanguageModelCompletionError::Other(err)))
334                    .boxed();
335            }
336        };
337
338        async move {
339            let Some(api_key) = api_key else {
340                return Err(LanguageModelCompletionError::NoApiKey {
341                    provider: PROVIDER_NAME,
342                });
343            };
344            let request =
345                open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
346            request.await.map_err(Into::into)
347        }
348        .boxed()
349    }
350}
351
352impl LanguageModel for OpenRouterLanguageModel {
353    fn id(&self) -> LanguageModelId {
354        self.id.clone()
355    }
356
357    fn name(&self) -> LanguageModelName {
358        LanguageModelName::from(self.model.display_name().to_string())
359    }
360
361    fn provider_id(&self) -> LanguageModelProviderId {
362        PROVIDER_ID
363    }
364
365    fn provider_name(&self) -> LanguageModelProviderName {
366        PROVIDER_NAME
367    }
368
369    fn supports_tools(&self) -> bool {
370        self.model.supports_tool_calls()
371    }
372
373    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
374        let model_id = self.model.id().trim().to_lowercase();
375        if model_id.contains("gemini") || model_id.contains("grok") {
376            LanguageModelToolSchemaFormat::JsonSchemaSubset
377        } else {
378            LanguageModelToolSchemaFormat::JsonSchema
379        }
380    }
381
382    fn telemetry_id(&self) -> String {
383        format!("openrouter/{}", self.model.id())
384    }
385
386    fn max_token_count(&self) -> u64 {
387        self.model.max_token_count()
388    }
389
390    fn max_output_tokens(&self) -> Option<u64> {
391        self.model.max_output_tokens()
392    }
393
394    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
395        match choice {
396            LanguageModelToolChoice::Auto => true,
397            LanguageModelToolChoice::Any => true,
398            LanguageModelToolChoice::None => true,
399        }
400    }
401
402    fn supports_images(&self) -> bool {
403        self.model.supports_images.unwrap_or(false)
404    }
405
406    fn count_tokens(
407        &self,
408        request: LanguageModelRequest,
409        cx: &App,
410    ) -> BoxFuture<'static, Result<u64>> {
411        count_open_router_tokens(request, self.model.clone(), cx)
412    }
413
414    fn stream_completion(
415        &self,
416        request: LanguageModelRequest,
417        cx: &AsyncApp,
418    ) -> BoxFuture<
419        'static,
420        Result<
421            futures::stream::BoxStream<
422                'static,
423                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
424            >,
425            LanguageModelCompletionError,
426        >,
427    > {
428        let request = into_open_router(request, &self.model, self.max_output_tokens());
429        let request = self.stream_completion(request, cx);
430        let future = self.request_limiter.stream(async move {
431            let response = request.await?;
432            Ok(OpenRouterEventMapper::new().map_stream(response))
433        });
434        async move { Ok(future.await?.boxed()) }.boxed()
435    }
436}
437
/// Converts a [`LanguageModelRequest`] into an OpenRouter chat request for `model`.
///
/// Message assembly:
/// - Text and image parts go through `add_message_content_part`, which may merge
///   them into the trailing message.
/// - Tool uses are appended to the trailing assistant message, or start a new
///   assistant message with no content.
/// - Each tool result becomes its own `Tool` message.
/// - Thinking and redacted-thinking content is dropped.
pub fn into_open_router(
    request: LanguageModelRequest,
    model: &Model,
    max_output_tokens: Option<u64>,
) -> open_router::Request {
    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                MessageContent::Text(text) => add_message_content_part(
                    open_router::MessagePart::Text { text },
                    message.role,
                    &mut messages,
                ),
                // Thinking content is not forwarded to OpenRouter.
                MessageContent::Thinking { .. } => {}
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_router::MessagePart::Image {
                            image_url: image.to_base64_url(),
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_router::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_router::ToolCallContent::Function {
                            function: open_router::FunctionContent {
                                name: tool_use.name.to_string(),
                                // Arguments are sent as a JSON string. A serialization
                                // failure falls back to an empty string.
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    // Group consecutive tool calls onto the trailing assistant message.
                    if let Some(open_router::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_router::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_router::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_router::MessagePart::Image {
                                image_url: image.to_base64_url(),
                            }]
                        }
                    };

                    messages.push(open_router::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_router::Request {
        model: model.id().into(),
        messages,
        stream: true,
        stop: request.stop,
        // Fallback temperature when the request does not specify one.
        temperature: request.temperature.unwrap_or(0.4),
        max_tokens: max_output_tokens,
        // Explicitly disables parallel tool calls when the model supports them
        // and tools are present; otherwise the field is omitted.
        // NOTE(review): presumably because the agent expects at most one tool
        // call per turn — confirm.
        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
            Some(false)
        } else {
            None
        },
        // Ask for usage accounting so `UsageUpdate` events can be emitted.
        usage: open_router::RequestUsage { include: true },
        // Reasoning is requested only when the request allows thinking AND the
        // model is configured in thinking mode.
        reasoning: if request.thinking_allowed
            && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
        {
            Some(open_router::Reasoning {
                effort: None,
                max_tokens: budget_tokens,
                exclude: Some(false),
                enabled: Some(true),
            })
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_router::ToolDefinition::Function {
                function: open_router::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_router::ToolChoice::Required,
            LanguageModelToolChoice::None => open_router::ToolChoice::None,
        }),
        provider: model.provider.clone(),
    }
}
553
554fn add_message_content_part(
555    new_part: open_router::MessagePart,
556    role: Role,
557    messages: &mut Vec<open_router::RequestMessage>,
558) {
559    match (role, messages.last_mut()) {
560        (Role::User, Some(open_router::RequestMessage::User { content }))
561        | (Role::System, Some(open_router::RequestMessage::System { content })) => {
562            content.push_part(new_part);
563        }
564        (
565            Role::Assistant,
566            Some(open_router::RequestMessage::Assistant {
567                content: Some(content),
568                ..
569            }),
570        ) => {
571            content.push_part(new_part);
572        }
573        _ => {
574            messages.push(match role {
575                Role::User => open_router::RequestMessage::User {
576                    content: open_router::MessageContent::from(vec![new_part]),
577                },
578                Role::Assistant => open_router::RequestMessage::Assistant {
579                    content: Some(open_router::MessageContent::from(vec![new_part])),
580                    tool_calls: Vec::new(),
581                },
582                Role::System => open_router::RequestMessage::System {
583                    content: open_router::MessageContent::from(vec![new_part]),
584                },
585            });
586        }
587    }
588}
589
590pub struct OpenRouterEventMapper {
591    tool_calls_by_index: HashMap<usize, RawToolCall>,
592}
593
594impl OpenRouterEventMapper {
595    pub fn new() -> Self {
596        Self {
597            tool_calls_by_index: HashMap::default(),
598        }
599    }
600
601    pub fn map_stream(
602        mut self,
603        events: Pin<
604            Box<
605                dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
606            >,
607        >,
608    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
609    {
610        events.flat_map(move |event| {
611            futures::stream::iter(match event {
612                Ok(event) => self.map_event(event),
613                Err(error) => vec![Err(error.into())],
614            })
615        })
616    }
617
618    pub fn map_event(
619        &mut self,
620        event: ResponseStreamEvent,
621    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
622        let Some(choice) = event.choices.first() else {
623            return vec![Err(LanguageModelCompletionError::from(anyhow!(
624                "Response contained no choices"
625            )))];
626        };
627
628        let mut events = Vec::new();
629        if let Some(reasoning) = choice.delta.reasoning.clone() {
630            events.push(Ok(LanguageModelCompletionEvent::Thinking {
631                text: reasoning,
632                signature: None,
633            }));
634        }
635
636        if let Some(content) = choice.delta.content.clone() {
637            // OpenRouter send empty content string with the reasoning content
638            // This is a workaround for the OpenRouter API bug
639            if !content.is_empty() {
640                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
641            }
642        }
643
644        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
645            for tool_call in tool_calls {
646                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
647
648                if let Some(tool_id) = tool_call.id.clone() {
649                    entry.id = tool_id;
650                }
651
652                if let Some(function) = tool_call.function.as_ref() {
653                    if let Some(name) = function.name.clone() {
654                        entry.name = name;
655                    }
656
657                    if let Some(arguments) = function.arguments.clone() {
658                        entry.arguments.push_str(&arguments);
659                    }
660                }
661            }
662        }
663
664        if let Some(usage) = event.usage {
665            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
666                input_tokens: usage.prompt_tokens,
667                output_tokens: usage.completion_tokens,
668                cache_creation_input_tokens: 0,
669                cache_read_input_tokens: 0,
670            })));
671        }
672
673        match choice.finish_reason.as_deref() {
674            Some("stop") => {
675                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
676            }
677            Some("tool_calls") => {
678                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
679                    match serde_json::Value::from_str(&tool_call.arguments) {
680                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
681                            LanguageModelToolUse {
682                                id: tool_call.id.clone().into(),
683                                name: tool_call.name.as_str().into(),
684                                is_input_complete: true,
685                                input,
686                                raw_input: tool_call.arguments.clone(),
687                            },
688                        )),
689                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
690                            id: tool_call.id.clone().into(),
691                            tool_name: tool_call.name.as_str().into(),
692                            raw_input: tool_call.arguments.clone().into(),
693                            json_parse_error: error.to_string(),
694                        }),
695                    }
696                }));
697
698                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
699            }
700            Some(stop_reason) => {
701                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
702                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
703            }
704            None => {}
705        }
706
707        events
708    }
709}
710
/// A tool call being assembled from streamed fragments.
#[derive(Default)]
struct RawToolCall {
    /// Tool call id; overwritten whenever a fragment carries one.
    id: String,
    /// Function name; overwritten whenever a fragment carries one.
    name: String,
    /// Concatenation of every argument fragment, parsed as JSON on completion.
    arguments: String,
}
717
718pub fn count_open_router_tokens(
719    request: LanguageModelRequest,
720    _model: open_router::Model,
721    cx: &App,
722) -> BoxFuture<'static, Result<u64>> {
723    cx.background_spawn(async move {
724        let messages = request
725            .messages
726            .into_iter()
727            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
728                role: match message.role {
729                    Role::User => "user".into(),
730                    Role::Assistant => "assistant".into(),
731                    Role::System => "system".into(),
732                },
733                content: Some(message.string_contents()),
734                name: None,
735                function_call: None,
736            })
737            .collect::<Vec<_>>();
738
739        tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
740    })
741    .boxed()
742}
743
/// Settings UI for entering, displaying, and resetting the OpenRouter API key.
struct ConfigurationView {
    /// Single-line editor into which the user types the API key.
    api_key_editor: Entity<Editor>,
    /// Provider state; observed so the view re-renders when it changes.
    state: gpui::Entity<State>,
    /// Initial authentication task; `Some` while credentials are still loading.
    load_credentials_task: Option<Task<()>>,
}
749
750impl ConfigurationView {
751    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
752        let api_key_editor = cx.new(|cx| {
753            let mut editor = Editor::single_line(window, cx);
754            editor.set_placeholder_text(
755                "sk_or_000000000000000000000000000000000000000000000000",
756                window,
757                cx,
758            );
759            editor
760        });
761
762        cx.observe(&state, |_, _, cx| {
763            cx.notify();
764        })
765        .detach();
766
767        let load_credentials_task = Some(cx.spawn_in(window, {
768            let state = state.clone();
769            async move |this, cx| {
770                if let Some(task) = state
771                    .update(cx, |state, cx| state.authenticate(cx))
772                    .log_err()
773                {
774                    let _ = task.await;
775                }
776
777                this.update(cx, |this, cx| {
778                    this.load_credentials_task = None;
779                    cx.notify();
780                })
781                .log_err();
782            }
783        }));
784
785        Self {
786            api_key_editor,
787            state,
788            load_credentials_task,
789        }
790    }
791
792    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
793        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
794        if api_key.is_empty() {
795            return;
796        }
797
798        let state = self.state.clone();
799        cx.spawn_in(window, async move |_, cx| {
800            state
801                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
802                .await
803        })
804        .detach_and_log_err(cx);
805    }
806
807    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
808        self.api_key_editor
809            .update(cx, |editor, cx| editor.set_text("", window, cx));
810
811        let state = self.state.clone();
812        cx.spawn_in(window, async move |_, cx| {
813            state
814                .update(cx, |state, cx| state.set_api_key(None, cx))?
815                .await
816        })
817        .detach_and_log_err(cx);
818    }
819
820    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
821        let settings = ThemeSettings::get_global(cx);
822        let text_style = TextStyle {
823            color: cx.theme().colors().text,
824            font_family: settings.ui_font.family.clone(),
825            font_features: settings.ui_font.features.clone(),
826            font_fallbacks: settings.ui_font.fallbacks.clone(),
827            font_size: rems(0.875).into(),
828            font_weight: settings.ui_font.weight,
829            font_style: FontStyle::Normal,
830            line_height: relative(1.3),
831            white_space: WhiteSpace::Normal,
832            ..Default::default()
833        };
834        EditorElement::new(
835            &self.api_key_editor,
836            EditorStyle {
837                background: cx.theme().colors().editor_background,
838                local_player: cx.theme().players().local(),
839                text: text_style,
840                ..Default::default()
841            },
842        )
843    }
844
845    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
846        !self.state.read(cx).is_authenticated()
847    }
848}
849
850impl Render for ConfigurationView {
851    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
852        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
853
854        if self.load_credentials_task.is_some() {
855            div().child(Label::new("Loading credentials...")).into_any()
856        } else if self.should_render_editor(cx) {
857            v_flex()
858                .size_full()
859                .on_action(cx.listener(Self::save_api_key))
860                .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:"))
861                .child(
862                    List::new()
863                        .child(InstructionListItem::new(
864                            "Create an API key by visiting",
865                            Some("OpenRouter's console"),
866                            Some("https://openrouter.ai/keys"),
867                        ))
868                        .child(InstructionListItem::text_only(
869                            "Ensure your OpenRouter account has credits",
870                        ))
871                        .child(InstructionListItem::text_only(
872                            "Paste your API key below and hit enter to start using the assistant",
873                        )),
874                )
875                .child(
876                    h_flex()
877                        .w_full()
878                        .my_2()
879                        .px_2()
880                        .py_1()
881                        .bg(cx.theme().colors().editor_background)
882                        .border_1()
883                        .border_color(cx.theme().colors().border)
884                        .rounded_sm()
885                        .child(self.render_api_key_editor(cx)),
886                )
887                .child(
888                    Label::new(
889                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
890                    )
891                    .size(LabelSize::Small).color(Color::Muted),
892                )
893                .into_any()
894        } else {
895            h_flex()
896                .mt_1()
897                .p_1()
898                .justify_between()
899                .rounded_md()
900                .border_1()
901                .border_color(cx.theme().colors().border)
902                .bg(cx.theme().colors().background)
903                .child(
904                    h_flex()
905                        .gap_1()
906                        .child(Icon::new(IconName::Check).color(Color::Success))
907                        .child(Label::new(if env_var_set {
908                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
909                        } else {
910                            "API key configured.".to_string()
911                        })),
912                )
913                .child(
914                    Button::new("reset-key", "Reset Key")
915                        .label_size(LabelSize::Small)
916                        .icon(Some(IconName::Trash))
917                        .icon_size(IconSize::Small)
918                        .icon_position(IconPosition::Start)
919                        .disabled(env_var_set)
920                        .when(env_var_set, |this| {
921                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
922                        })
923                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
924                )
925                .into_any()
926        }
927    }
928}