open_router.rs

  1use anyhow::{Result, anyhow};
  2use collections::HashMap;
  3use editor::{Editor, EditorElement, EditorStyle};
  4use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
  5use gpui::{
  6    AnyView, App, AsyncApp, Context, Entity, FontStyle, SharedString, Task, TextStyle, WhiteSpace,
  7};
  8use http_client::HttpClient;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 11    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 12    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 13    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
 14    LanguageModelToolUse, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
 15};
 16use open_router::{
 17    Model, ModelMode as OpenRouterModelMode, Provider, ResponseStreamEvent, list_models,
 18};
 19use schemars::JsonSchema;
 20use serde::{Deserialize, Serialize};
 21use settings::{Settings, SettingsStore};
 22use std::pin::Pin;
 23use std::str::FromStr as _;
 24use std::sync::{Arc, LazyLock};
 25use theme::ThemeSettings;
 26use ui::{Icon, IconName, List, Tooltip, prelude::*};
 27use util::ResultExt;
 28use zed_env_vars::{EnvVar, env_var};
 29
 30use crate::{api_key::ApiKeyState, ui::InstructionListItem};
 31
/// Stable identifier under which this provider is registered.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter");
/// Human-readable provider name; also reported in `NoApiKey` errors.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter");

/// Name of the environment variable that can supply the API key.
const API_KEY_ENV_VAR_NAME: &str = "OPENROUTER_API_KEY";
/// Environment-variable key source handed to `ApiKeyState::load_if_needed`
/// by `State::authenticate`.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 37
/// User-configurable settings for the OpenRouter provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenRouterSettings {
    /// Base URL of the OpenRouter API. The stored API key is also stored and
    /// looked up per URL via `ApiKeyState`.
    pub api_url: String,
    /// Models declared in settings; merged with the models fetched from the API.
    pub available_models: Vec<AvailableModel>,
}

/// A model declared in the user's settings.
///
/// An entry whose `name` matches a model fetched from the OpenRouter API
/// replaces that model; otherwise it is appended to the model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// OpenRouter model identifier, used to match against fetched models.
    pub name: String,
    /// Optional human-readable name for the model.
    pub display_name: Option<String>,
    /// Token limit for the model.
    pub max_tokens: u64,
    // NOTE(review): `max_output_tokens` and `max_completion_tokens` are not
    // copied into `open_router::Model` when settings models are built in this
    // file — confirm whether they are consumed anywhere else.
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    /// Whether the model supports tool calls.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input; treated as `false` when unset.
    pub supports_images: Option<bool>,
    /// Reasoning mode; `ModelMode::Default` is used when omitted.
    pub mode: Option<ModelMode>,
    /// Provider routing preferences, sent as the request's `provider` field.
    pub provider: Option<Provider>,
}
 56
 57#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
 58#[serde(tag = "type", rename_all = "lowercase")]
 59pub enum ModelMode {
 60    #[default]
 61    Default,
 62    Thinking {
 63        budget_tokens: Option<u32>,
 64    },
 65}
 66
 67impl From<ModelMode> for OpenRouterModelMode {
 68    fn from(value: ModelMode) -> Self {
 69        match value {
 70            ModelMode::Default => OpenRouterModelMode::Default,
 71            ModelMode::Thinking { budget_tokens } => {
 72                OpenRouterModelMode::Thinking { budget_tokens }
 73            }
 74        }
 75    }
 76}
 77
 78impl From<OpenRouterModelMode> for ModelMode {
 79    fn from(value: OpenRouterModelMode) -> Self {
 80        match value {
 81            OpenRouterModelMode::Default => ModelMode::Default,
 82            OpenRouterModelMode::Thinking { budget_tokens } => {
 83                ModelMode::Thinking { budget_tokens }
 84            }
 85        }
 86    }
 87}
 88
/// Language-model provider backed by the OpenRouter API.
pub struct OpenRouterLanguageModelProvider {
    /// HTTP client shared with every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared, observable provider state (API key and fetched models).
    state: gpui::Entity<State>,
}
 93
 94pub struct State {
 95    api_key_state: ApiKeyState,
 96    http_client: Arc<dyn HttpClient>,
 97    available_models: Vec<open_router::Model>,
 98    fetch_models_task: Option<Task<Result<(), LanguageModelCompletionError>>>,
 99}
100
101impl State {
102    fn is_authenticated(&self) -> bool {
103        self.api_key_state.has_key()
104    }
105
106    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
107        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
108        self.api_key_state
109            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
110    }
111
112    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
113        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
114        let task = self.api_key_state.load_if_needed(
115            api_url,
116            &API_KEY_ENV_VAR,
117            |this| &mut this.api_key_state,
118            cx,
119        );
120
121        cx.spawn(async move |this, cx| {
122            let result = task.await;
123            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
124                .ok();
125            result
126        })
127    }
128
129    fn fetch_models(
130        &mut self,
131        cx: &mut Context<Self>,
132    ) -> Task<Result<(), LanguageModelCompletionError>> {
133        let http_client = self.http_client.clone();
134        let api_url = OpenRouterLanguageModelProvider::api_url(cx);
135        let Some(api_key) = self.api_key_state.key(&api_url) else {
136            return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
137                provider: PROVIDER_NAME,
138            }));
139        };
140        cx.spawn(async move |this, cx| {
141            let models = list_models(http_client.as_ref(), &api_url, &api_key)
142                .await
143                .map_err(|e| {
144                    LanguageModelCompletionError::Other(anyhow::anyhow!(
145                        "OpenRouter error: {:?}",
146                        e
147                    ))
148                })?;
149
150            this.update(cx, |this, cx| {
151                this.available_models = models;
152                cx.notify();
153            })
154            .map_err(|e| LanguageModelCompletionError::Other(e))?;
155
156            Ok(())
157        })
158    }
159
160    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
161        if self.is_authenticated() {
162            let task = self.fetch_models(cx);
163            self.fetch_models_task.replace(task);
164        } else {
165            self.available_models = Vec::new();
166        }
167    }
168}
169
170impl OpenRouterLanguageModelProvider {
171    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
172        let state = cx.new(|cx| {
173            cx.observe_global::<SettingsStore>({
174                let mut last_settings = OpenRouterLanguageModelProvider::settings(cx).clone();
175                move |this: &mut State, cx| {
176                    let current_settings = OpenRouterLanguageModelProvider::settings(cx);
177                    let settings_changed = current_settings != &last_settings;
178                    if settings_changed {
179                        last_settings = current_settings.clone();
180                        this.authenticate(cx).detach();
181                        cx.notify();
182                    }
183                }
184            })
185            .detach();
186            State {
187                api_key_state: ApiKeyState::new(Self::api_url(cx)),
188                http_client: http_client.clone(),
189                available_models: Vec::new(),
190                fetch_models_task: None,
191            }
192        });
193
194        Self { http_client, state }
195    }
196
197    fn settings(cx: &App) -> &OpenRouterSettings {
198        &crate::AllLanguageModelSettings::get_global(cx).open_router
199    }
200
201    fn api_url(cx: &App) -> SharedString {
202        SharedString::new(Self::settings(cx).api_url.as_str())
203    }
204
205    fn create_language_model(&self, model: open_router::Model) -> Arc<dyn LanguageModel> {
206        Arc::new(OpenRouterLanguageModel {
207            id: LanguageModelId::from(model.id().to_string()),
208            model,
209            state: self.state.clone(),
210            http_client: self.http_client.clone(),
211            request_limiter: RateLimiter::new(4),
212        })
213    }
214}
215
216impl LanguageModelProviderState for OpenRouterLanguageModelProvider {
217    type ObservableEntity = State;
218
219    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
220        Some(self.state.clone())
221    }
222}
223
224impl LanguageModelProvider for OpenRouterLanguageModelProvider {
225    fn id(&self) -> LanguageModelProviderId {
226        PROVIDER_ID
227    }
228
229    fn name(&self) -> LanguageModelProviderName {
230        PROVIDER_NAME
231    }
232
233    fn icon(&self) -> IconName {
234        IconName::AiOpenRouter
235    }
236
237    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
238        Some(self.create_language_model(open_router::Model::default()))
239    }
240
241    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
242        Some(self.create_language_model(open_router::Model::default_fast()))
243    }
244
245    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
246        let mut models_from_api = self.state.read(cx).available_models.clone();
247        let mut settings_models = Vec::new();
248
249        for model in &Self::settings(cx).available_models {
250            settings_models.push(open_router::Model {
251                name: model.name.clone(),
252                display_name: model.display_name.clone(),
253                max_tokens: model.max_tokens,
254                supports_tools: model.supports_tools,
255                supports_images: model.supports_images,
256                mode: model.mode.clone().unwrap_or_default().into(),
257                provider: model.provider.clone(),
258            });
259        }
260
261        for settings_model in &settings_models {
262            if let Some(pos) = models_from_api
263                .iter()
264                .position(|m| m.name == settings_model.name)
265            {
266                models_from_api[pos] = settings_model.clone();
267            } else {
268                models_from_api.push(settings_model.clone());
269            }
270        }
271
272        models_from_api
273            .into_iter()
274            .map(|model| self.create_language_model(model))
275            .collect()
276    }
277
278    fn is_authenticated(&self, cx: &App) -> bool {
279        self.state.read(cx).is_authenticated()
280    }
281
282    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
283        self.state.update(cx, |state, cx| state.authenticate(cx))
284    }
285
286    fn configuration_view(
287        &self,
288        _target_agent: language_model::ConfigurationViewTargetAgent,
289        window: &mut Window,
290        cx: &mut App,
291    ) -> AnyView {
292        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
293            .into()
294    }
295
296    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
297        self.state
298            .update(cx, |state, cx| state.set_api_key(None, cx))
299    }
300}
301
302pub struct OpenRouterLanguageModel {
303    id: LanguageModelId,
304    model: open_router::Model,
305    state: gpui::Entity<State>,
306    http_client: Arc<dyn HttpClient>,
307    request_limiter: RateLimiter,
308}
309
310impl OpenRouterLanguageModel {
311    fn stream_completion(
312        &self,
313        request: open_router::Request,
314        cx: &AsyncApp,
315    ) -> BoxFuture<
316        'static,
317        Result<
318            futures::stream::BoxStream<
319                'static,
320                Result<ResponseStreamEvent, open_router::OpenRouterError>,
321            >,
322            LanguageModelCompletionError,
323        >,
324    > {
325        let http_client = self.http_client.clone();
326        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
327            let api_url = OpenRouterLanguageModelProvider::api_url(cx);
328            (state.api_key_state.key(&api_url), api_url)
329        }) else {
330            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
331        };
332
333        async move {
334            let Some(api_key) = api_key else {
335                return Err(LanguageModelCompletionError::NoApiKey {
336                    provider: PROVIDER_NAME,
337                });
338            };
339            let request =
340                open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
341            request.await.map_err(Into::into)
342        }
343        .boxed()
344    }
345}
346
347impl LanguageModel for OpenRouterLanguageModel {
348    fn id(&self) -> LanguageModelId {
349        self.id.clone()
350    }
351
352    fn name(&self) -> LanguageModelName {
353        LanguageModelName::from(self.model.display_name().to_string())
354    }
355
356    fn provider_id(&self) -> LanguageModelProviderId {
357        PROVIDER_ID
358    }
359
360    fn provider_name(&self) -> LanguageModelProviderName {
361        PROVIDER_NAME
362    }
363
364    fn supports_tools(&self) -> bool {
365        self.model.supports_tool_calls()
366    }
367
368    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
369        let model_id = self.model.id().trim().to_lowercase();
370        if model_id.contains("gemini") || model_id.contains("grok") {
371            LanguageModelToolSchemaFormat::JsonSchemaSubset
372        } else {
373            LanguageModelToolSchemaFormat::JsonSchema
374        }
375    }
376
377    fn telemetry_id(&self) -> String {
378        format!("openrouter/{}", self.model.id())
379    }
380
381    fn max_token_count(&self) -> u64 {
382        self.model.max_token_count()
383    }
384
385    fn max_output_tokens(&self) -> Option<u64> {
386        self.model.max_output_tokens()
387    }
388
389    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
390        match choice {
391            LanguageModelToolChoice::Auto => true,
392            LanguageModelToolChoice::Any => true,
393            LanguageModelToolChoice::None => true,
394        }
395    }
396
397    fn supports_images(&self) -> bool {
398        self.model.supports_images.unwrap_or(false)
399    }
400
401    fn count_tokens(
402        &self,
403        request: LanguageModelRequest,
404        cx: &App,
405    ) -> BoxFuture<'static, Result<u64>> {
406        count_open_router_tokens(request, self.model.clone(), cx)
407    }
408
409    fn stream_completion(
410        &self,
411        request: LanguageModelRequest,
412        cx: &AsyncApp,
413    ) -> BoxFuture<
414        'static,
415        Result<
416            futures::stream::BoxStream<
417                'static,
418                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
419            >,
420            LanguageModelCompletionError,
421        >,
422    > {
423        let request = into_open_router(request, &self.model, self.max_output_tokens());
424        let request = self.stream_completion(request, cx);
425        let future = self.request_limiter.stream(async move {
426            let response = request.await?;
427            Ok(OpenRouterEventMapper::new().map_stream(response))
428        });
429        async move { Ok(future.await?.boxed()) }.boxed()
430    }
431}
432
433pub fn into_open_router(
434    request: LanguageModelRequest,
435    model: &Model,
436    max_output_tokens: Option<u64>,
437) -> open_router::Request {
438    let mut messages = Vec::new();
439    for message in request.messages {
440        for content in message.content {
441            match content {
442                MessageContent::Text(text) => add_message_content_part(
443                    open_router::MessagePart::Text { text },
444                    message.role,
445                    &mut messages,
446                ),
447                MessageContent::Thinking { .. } => {}
448                MessageContent::RedactedThinking(_) => {}
449                MessageContent::Image(image) => {
450                    add_message_content_part(
451                        open_router::MessagePart::Image {
452                            image_url: image.to_base64_url(),
453                        },
454                        message.role,
455                        &mut messages,
456                    );
457                }
458                MessageContent::ToolUse(tool_use) => {
459                    let tool_call = open_router::ToolCall {
460                        id: tool_use.id.to_string(),
461                        content: open_router::ToolCallContent::Function {
462                            function: open_router::FunctionContent {
463                                name: tool_use.name.to_string(),
464                                arguments: serde_json::to_string(&tool_use.input)
465                                    .unwrap_or_default(),
466                            },
467                        },
468                    };
469
470                    if let Some(open_router::RequestMessage::Assistant { tool_calls, .. }) =
471                        messages.last_mut()
472                    {
473                        tool_calls.push(tool_call);
474                    } else {
475                        messages.push(open_router::RequestMessage::Assistant {
476                            content: None,
477                            tool_calls: vec![tool_call],
478                        });
479                    }
480                }
481                MessageContent::ToolResult(tool_result) => {
482                    let content = match &tool_result.content {
483                        LanguageModelToolResultContent::Text(text) => {
484                            vec![open_router::MessagePart::Text {
485                                text: text.to_string(),
486                            }]
487                        }
488                        LanguageModelToolResultContent::Image(image) => {
489                            vec![open_router::MessagePart::Image {
490                                image_url: image.to_base64_url(),
491                            }]
492                        }
493                    };
494
495                    messages.push(open_router::RequestMessage::Tool {
496                        content: content.into(),
497                        tool_call_id: tool_result.tool_use_id.to_string(),
498                    });
499                }
500            }
501        }
502    }
503
504    open_router::Request {
505        model: model.id().into(),
506        messages,
507        stream: true,
508        stop: request.stop,
509        temperature: request.temperature.unwrap_or(0.4),
510        max_tokens: max_output_tokens,
511        parallel_tool_calls: if model.supports_parallel_tool_calls() && !request.tools.is_empty() {
512            Some(false)
513        } else {
514            None
515        },
516        usage: open_router::RequestUsage { include: true },
517        reasoning: if request.thinking_allowed
518            && let OpenRouterModelMode::Thinking { budget_tokens } = model.mode
519        {
520            Some(open_router::Reasoning {
521                effort: None,
522                max_tokens: budget_tokens,
523                exclude: Some(false),
524                enabled: Some(true),
525            })
526        } else {
527            None
528        },
529        tools: request
530            .tools
531            .into_iter()
532            .map(|tool| open_router::ToolDefinition::Function {
533                function: open_router::FunctionDefinition {
534                    name: tool.name,
535                    description: Some(tool.description),
536                    parameters: Some(tool.input_schema),
537                },
538            })
539            .collect(),
540        tool_choice: request.tool_choice.map(|choice| match choice {
541            LanguageModelToolChoice::Auto => open_router::ToolChoice::Auto,
542            LanguageModelToolChoice::Any => open_router::ToolChoice::Required,
543            LanguageModelToolChoice::None => open_router::ToolChoice::None,
544        }),
545        provider: model.provider.clone(),
546    }
547}
548
549fn add_message_content_part(
550    new_part: open_router::MessagePart,
551    role: Role,
552    messages: &mut Vec<open_router::RequestMessage>,
553) {
554    match (role, messages.last_mut()) {
555        (Role::User, Some(open_router::RequestMessage::User { content }))
556        | (Role::System, Some(open_router::RequestMessage::System { content })) => {
557            content.push_part(new_part);
558        }
559        (
560            Role::Assistant,
561            Some(open_router::RequestMessage::Assistant {
562                content: Some(content),
563                ..
564            }),
565        ) => {
566            content.push_part(new_part);
567        }
568        _ => {
569            messages.push(match role {
570                Role::User => open_router::RequestMessage::User {
571                    content: open_router::MessageContent::from(vec![new_part]),
572                },
573                Role::Assistant => open_router::RequestMessage::Assistant {
574                    content: Some(open_router::MessageContent::from(vec![new_part])),
575                    tool_calls: Vec::new(),
576                },
577                Role::System => open_router::RequestMessage::System {
578                    content: open_router::MessageContent::from(vec![new_part]),
579                },
580            });
581        }
582    }
583}
584
585pub struct OpenRouterEventMapper {
586    tool_calls_by_index: HashMap<usize, RawToolCall>,
587}
588
589impl OpenRouterEventMapper {
590    pub fn new() -> Self {
591        Self {
592            tool_calls_by_index: HashMap::default(),
593        }
594    }
595
596    pub fn map_stream(
597        mut self,
598        events: Pin<
599            Box<
600                dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
601            >,
602        >,
603    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
604    {
605        events.flat_map(move |event| {
606            futures::stream::iter(match event {
607                Ok(event) => self.map_event(event),
608                Err(error) => vec![Err(error.into())],
609            })
610        })
611    }
612
613    pub fn map_event(
614        &mut self,
615        event: ResponseStreamEvent,
616    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
617        let Some(choice) = event.choices.first() else {
618            return vec![Err(LanguageModelCompletionError::from(anyhow!(
619                "Response contained no choices"
620            )))];
621        };
622
623        let mut events = Vec::new();
624        if let Some(reasoning) = choice.delta.reasoning.clone() {
625            events.push(Ok(LanguageModelCompletionEvent::Thinking {
626                text: reasoning,
627                signature: None,
628            }));
629        }
630
631        if let Some(content) = choice.delta.content.clone() {
632            // OpenRouter send empty content string with the reasoning content
633            // This is a workaround for the OpenRouter API bug
634            if !content.is_empty() {
635                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
636            }
637        }
638
639        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
640            for tool_call in tool_calls {
641                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
642
643                if let Some(tool_id) = tool_call.id.clone() {
644                    entry.id = tool_id;
645                }
646
647                if let Some(function) = tool_call.function.as_ref() {
648                    if let Some(name) = function.name.clone() {
649                        entry.name = name;
650                    }
651
652                    if let Some(arguments) = function.arguments.clone() {
653                        entry.arguments.push_str(&arguments);
654                    }
655                }
656            }
657        }
658
659        if let Some(usage) = event.usage {
660            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
661                input_tokens: usage.prompt_tokens,
662                output_tokens: usage.completion_tokens,
663                cache_creation_input_tokens: 0,
664                cache_read_input_tokens: 0,
665            })));
666        }
667
668        match choice.finish_reason.as_deref() {
669            Some("stop") => {
670                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
671            }
672            Some("tool_calls") => {
673                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
674                    match serde_json::Value::from_str(&tool_call.arguments) {
675                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
676                            LanguageModelToolUse {
677                                id: tool_call.id.clone().into(),
678                                name: tool_call.name.as_str().into(),
679                                is_input_complete: true,
680                                input,
681                                raw_input: tool_call.arguments.clone(),
682                            },
683                        )),
684                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
685                            id: tool_call.id.clone().into(),
686                            tool_name: tool_call.name.as_str().into(),
687                            raw_input: tool_call.arguments.clone().into(),
688                            json_parse_error: error.to_string(),
689                        }),
690                    }
691                }));
692
693                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
694            }
695            Some(stop_reason) => {
696                log::error!("Unexpected OpenRouter stop_reason: {stop_reason:?}",);
697                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
698            }
699            None => {}
700        }
701
702        events
703    }
704}
705
706#[derive(Default)]
707struct RawToolCall {
708    id: String,
709    name: String,
710    arguments: String,
711}
712
713pub fn count_open_router_tokens(
714    request: LanguageModelRequest,
715    _model: open_router::Model,
716    cx: &App,
717) -> BoxFuture<'static, Result<u64>> {
718    cx.background_spawn(async move {
719        let messages = request
720            .messages
721            .into_iter()
722            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
723                role: match message.role {
724                    Role::User => "user".into(),
725                    Role::Assistant => "assistant".into(),
726                    Role::System => "system".into(),
727                },
728                content: Some(message.string_contents()),
729                name: None,
730                function_call: None,
731            })
732            .collect::<Vec<_>>();
733
734        tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages).map(|tokens| tokens as u64)
735    })
736    .boxed()
737}
738
739struct ConfigurationView {
740    api_key_editor: Entity<Editor>,
741    state: gpui::Entity<State>,
742    load_credentials_task: Option<Task<()>>,
743}
744
745impl ConfigurationView {
746    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
747        let api_key_editor = cx.new(|cx| {
748            let mut editor = Editor::single_line(window, cx);
749            editor.set_placeholder_text(
750                "sk_or_000000000000000000000000000000000000000000000000",
751                window,
752                cx,
753            );
754            editor
755        });
756
757        cx.observe(&state, |_, _, cx| {
758            cx.notify();
759        })
760        .detach();
761
762        let load_credentials_task = Some(cx.spawn_in(window, {
763            let state = state.clone();
764            async move |this, cx| {
765                if let Some(task) = state
766                    .update(cx, |state, cx| state.authenticate(cx))
767                    .log_err()
768                {
769                    let _ = task.await;
770                }
771
772                this.update(cx, |this, cx| {
773                    this.load_credentials_task = None;
774                    cx.notify();
775                })
776                .log_err();
777            }
778        }));
779
780        Self {
781            api_key_editor,
782            state,
783            load_credentials_task,
784        }
785    }
786
787    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
788        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
789        if api_key.is_empty() {
790            return;
791        }
792
793        // url changes can cause the editor to be displayed again
794        self.api_key_editor
795            .update(cx, |editor, cx| editor.set_text("", window, cx));
796
797        let state = self.state.clone();
798        cx.spawn_in(window, async move |_, cx| {
799            state
800                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
801                .await
802        })
803        .detach_and_log_err(cx);
804    }
805
806    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
807        self.api_key_editor
808            .update(cx, |editor, cx| editor.set_text("", window, cx));
809
810        let state = self.state.clone();
811        cx.spawn_in(window, async move |_, cx| {
812            state
813                .update(cx, |state, cx| state.set_api_key(None, cx))?
814                .await
815        })
816        .detach_and_log_err(cx);
817    }
818
819    fn render_api_key_editor(&self, cx: &mut Context<Self>) -> impl IntoElement {
820        let settings = ThemeSettings::get_global(cx);
821        let text_style = TextStyle {
822            color: cx.theme().colors().text,
823            font_family: settings.ui_font.family.clone(),
824            font_features: settings.ui_font.features.clone(),
825            font_fallbacks: settings.ui_font.fallbacks.clone(),
826            font_size: rems(0.875).into(),
827            font_weight: settings.ui_font.weight,
828            font_style: FontStyle::Normal,
829            line_height: relative(1.3),
830            white_space: WhiteSpace::Normal,
831            ..Default::default()
832        };
833        EditorElement::new(
834            &self.api_key_editor,
835            EditorStyle {
836                background: cx.theme().colors().editor_background,
837                local_player: cx.theme().players().local(),
838                text: text_style,
839                ..Default::default()
840            },
841        )
842    }
843
844    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
845        !self.state.read(cx).is_authenticated()
846    }
847}
848
impl Render for ConfigurationView {
    /// Renders one of three states: a loading message while credentials load,
    /// the API-key entry form when no key is available, or a confirmation row
    /// with a "Reset Key" button once a key is set.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // A key supplied via the environment cannot be reset from this UI;
        // the Reset button is disabled in that case.
        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials...")).into_any()
        } else if self.should_render_editor(cx) {
            // No key yet: show instructions plus the editor. Pressing Enter
            // (menu::Confirm) saves the key via `save_api_key`.
            v_flex()
                .size_full()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenRouter, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create an API key by visiting",
                            Some("OpenRouter's console"),
                            Some("https://openrouter.ai/keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenRouter account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(
                    h_flex()
                        .w_full()
                        .my_2()
                        .px_2()
                        .py_1()
                        .bg(cx.theme().colors().editor_background)
                        .border_1()
                        .border_color(cx.theme().colors().border)
                        .rounded_sm()
                        .child(self.render_api_key_editor(cx)),
                )
                .child(
                    Label::new(
                        format!("You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Key configured: show where it came from and offer a reset.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-key", "Reset Key")
                        .label_size(LabelSize::Small)
                        .icon(Some(IconName::Trash))
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .disabled(env_var_set)
                        // Explain why the button is disabled for env-var keys.
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        }
    }
}