open_ai.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use collections::{BTreeMap, HashMap};
   3use credentials_provider::CredentialsProvider;
   4
   5use fs::Fs;
   6use futures::Stream;
   7use futures::{FutureExt, StreamExt, future::BoxFuture};
   8use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
   9use http_client::HttpClient;
  10use language_model::{
  11    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  12    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  13    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  14    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  15    RateLimiter, Role, StopReason, TokenUsage,
  16};
  17use menu;
  18use open_ai::{ImageUrl, Model, ResponseStreamEvent, stream_completion};
  19use schemars::JsonSchema;
  20use serde::{Deserialize, Serialize};
  21use settings::{Settings, SettingsStore, update_settings_file};
  22use std::pin::Pin;
  23use std::str::FromStr as _;
  24use std::sync::Arc;
  25use strum::IntoEnumIterator;
  26
  27use ui::{ElevationIndex, List, Tooltip, prelude::*};
  28use ui_input::SingleLineInput;
  29use util::ResultExt;
  30
  31use crate::OpenAiSettingsContent;
  32use crate::{AllLanguageModelSettings, ui::InstructionListItem};
  33
  34const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
  35const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
  36
/// User-facing settings for the OpenAI provider, resolved from the global
/// settings store.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL for API requests; empty means "use the default OpenAI URL".
    pub api_url: String,
    /// Extra models declared in settings, merged over the built-in model list.
    pub available_models: Vec<AvailableModel>,
}
  42
/// A user-declared model entry from settings, turned into
/// `open_ai::Model::Custom` when building the provider's model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model identifier sent to the API (also used as the map key, so it
    /// overrides a built-in model with the same id).
    pub name: String,
    /// Optional human-readable name shown in the UI.
    pub display_name: Option<String>,
    /// Maximum context window size, in tokens.
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
}
  51
/// Language-model provider backed by the OpenAI API. Shares a single HTTP
/// client and credential `State` across all models it creates.
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}
  56
/// Credential state shared between the provider and its models.
pub struct State {
    // The OpenAI API key, when authenticated.
    api_key: Option<String>,
    // True when the key was sourced from the OPENAI_API_KEY environment
    // variable rather than the credentials store.
    api_key_from_env: bool,
    // Keeps the global settings-store observer alive for this state's lifetime.
    _subscription: Subscription,
}
  62
  63const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
  64
  65impl State {
  66    //
  67    fn is_authenticated(&self) -> bool {
  68        self.api_key.is_some()
  69    }
  70
  71    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
  72        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  73        let api_url = AllLanguageModelSettings::get_global(cx)
  74            .openai
  75            .api_url
  76            .clone();
  77        cx.spawn(async move |this, cx| {
  78            credentials_provider
  79                .delete_credentials(&api_url, &cx)
  80                .await
  81                .log_err();
  82            this.update(cx, |this, cx| {
  83                this.api_key = None;
  84                this.api_key_from_env = false;
  85                cx.notify();
  86            })
  87        })
  88    }
  89
  90    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
  91        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  92        let api_url = AllLanguageModelSettings::get_global(cx)
  93            .openai
  94            .api_url
  95            .clone();
  96        cx.spawn(async move |this, cx| {
  97            credentials_provider
  98                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
  99                .await
 100                .log_err();
 101            this.update(cx, |this, cx| {
 102                this.api_key = Some(api_key);
 103                cx.notify();
 104            })
 105        })
 106    }
 107
 108    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 109        if self.is_authenticated() {
 110            return Task::ready(Ok(()));
 111        }
 112
 113        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 114        let api_url = AllLanguageModelSettings::get_global(cx)
 115            .openai
 116            .api_url
 117            .clone();
 118        cx.spawn(async move |this, cx| {
 119            let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
 120                (api_key, true)
 121            } else {
 122                let (_, api_key) = credentials_provider
 123                    .read_credentials(&api_url, &cx)
 124                    .await?
 125                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 126                (
 127                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 128                    false,
 129                )
 130            };
 131            this.update(cx, |this, cx| {
 132                this.api_key = Some(api_key);
 133                this.api_key_from_env = from_env;
 134                cx.notify();
 135            })?;
 136
 137            Ok(())
 138        })
 139    }
 140}
 141
 142impl OpenAiLanguageModelProvider {
 143    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 144        let state = cx.new(|cx| State {
 145            api_key: None,
 146            api_key_from_env: false,
 147            _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
 148                cx.notify();
 149            }),
 150        });
 151
 152        Self { http_client, state }
 153    }
 154
 155    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
 156        Arc::new(OpenAiLanguageModel {
 157            id: LanguageModelId::from(model.id().to_string()),
 158            model,
 159            state: self.state.clone(),
 160            http_client: self.http_client.clone(),
 161            request_limiter: RateLimiter::new(4),
 162        })
 163    }
 164}
 165
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    /// Exposes the credential state so the framework can observe
    /// authentication changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 173
 174impl LanguageModelProvider for OpenAiLanguageModelProvider {
 175    fn id(&self) -> LanguageModelProviderId {
 176        PROVIDER_ID
 177    }
 178
 179    fn name(&self) -> LanguageModelProviderName {
 180        PROVIDER_NAME
 181    }
 182
 183    fn icon(&self) -> IconName {
 184        IconName::AiOpenAi
 185    }
 186
 187    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 188        Some(self.create_language_model(open_ai::Model::default()))
 189    }
 190
 191    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 192        Some(self.create_language_model(open_ai::Model::default_fast()))
 193    }
 194
 195    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 196        let mut models = BTreeMap::default();
 197
 198        // Add base models from open_ai::Model::iter()
 199        for model in open_ai::Model::iter() {
 200            if !matches!(model, open_ai::Model::Custom { .. }) {
 201                models.insert(model.id().to_string(), model);
 202            }
 203        }
 204
 205        // Override with available models from settings
 206        for model in &AllLanguageModelSettings::get_global(cx)
 207            .openai
 208            .available_models
 209        {
 210            models.insert(
 211                model.name.clone(),
 212                open_ai::Model::Custom {
 213                    name: model.name.clone(),
 214                    display_name: model.display_name.clone(),
 215                    max_tokens: model.max_tokens,
 216                    max_output_tokens: model.max_output_tokens,
 217                    max_completion_tokens: model.max_completion_tokens,
 218                },
 219            );
 220        }
 221
 222        models
 223            .into_values()
 224            .map(|model| self.create_language_model(model))
 225            .collect()
 226    }
 227
 228    fn is_authenticated(&self, cx: &App) -> bool {
 229        self.state.read(cx).is_authenticated()
 230    }
 231
 232    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 233        self.state.update(cx, |state, cx| state.authenticate(cx))
 234    }
 235
 236    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
 237        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 238            .into()
 239    }
 240
 241    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 242        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 243    }
 244}
 245
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    // Shared credential state; read for the API key on each request.
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Caps the number of concurrent in-flight requests for this model.
    request_limiter: RateLimiter,
}
 253
impl OpenAiLanguageModel {
    /// Issues a streaming completion request against the configured endpoint,
    /// gated by `request_limiter`. Returns a future resolving to the raw
    /// OpenAI event stream.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        // Snapshot the API key and URL now; the app state may be gone by the
        // time the returned future is polled.
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).openai;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            // Missing key is only reported when a request is actually made.
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 283
 284impl LanguageModel for OpenAiLanguageModel {
 285    fn id(&self) -> LanguageModelId {
 286        self.id.clone()
 287    }
 288
 289    fn name(&self) -> LanguageModelName {
 290        LanguageModelName::from(self.model.display_name().to_string())
 291    }
 292
 293    fn provider_id(&self) -> LanguageModelProviderId {
 294        PROVIDER_ID
 295    }
 296
 297    fn provider_name(&self) -> LanguageModelProviderName {
 298        PROVIDER_NAME
 299    }
 300
 301    fn supports_tools(&self) -> bool {
 302        true
 303    }
 304
 305    fn supports_images(&self) -> bool {
 306        false
 307    }
 308
 309    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 310        match choice {
 311            LanguageModelToolChoice::Auto => true,
 312            LanguageModelToolChoice::Any => true,
 313            LanguageModelToolChoice::None => true,
 314        }
 315    }
 316
 317    fn telemetry_id(&self) -> String {
 318        format!("openai/{}", self.model.id())
 319    }
 320
 321    fn max_token_count(&self) -> u64 {
 322        self.model.max_token_count()
 323    }
 324
 325    fn max_output_tokens(&self) -> Option<u64> {
 326        self.model.max_output_tokens()
 327    }
 328
 329    fn count_tokens(
 330        &self,
 331        request: LanguageModelRequest,
 332        cx: &App,
 333    ) -> BoxFuture<'static, Result<u64>> {
 334        count_open_ai_tokens(request, self.model.clone(), cx)
 335    }
 336
 337    fn stream_completion(
 338        &self,
 339        request: LanguageModelRequest,
 340        cx: &AsyncApp,
 341    ) -> BoxFuture<
 342        'static,
 343        Result<
 344            futures::stream::BoxStream<
 345                'static,
 346                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
 347            >,
 348            LanguageModelCompletionError,
 349        >,
 350    > {
 351        let request = into_open_ai(
 352            request,
 353            self.model.id(),
 354            self.model.supports_parallel_tool_calls(),
 355            self.max_output_tokens(),
 356        );
 357        let completions = self.stream_completion(request, cx);
 358        async move {
 359            let mapper = OpenAiEventMapper::new();
 360            Ok(mapper.map_stream(completions.await?).boxed())
 361        }
 362        .boxed()
 363    }
 364}
 365
 366pub fn into_open_ai(
 367    request: LanguageModelRequest,
 368    model_id: &str,
 369    supports_parallel_tool_calls: bool,
 370    max_output_tokens: Option<u64>,
 371) -> open_ai::Request {
 372    let stream = !model_id.starts_with("o1-");
 373
 374    let mut messages = Vec::new();
 375    for message in request.messages {
 376        for content in message.content {
 377            match content {
 378                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
 379                    add_message_content_part(
 380                        open_ai::MessagePart::Text { text: text },
 381                        message.role,
 382                        &mut messages,
 383                    )
 384                }
 385                MessageContent::RedactedThinking(_) => {}
 386                MessageContent::Image(image) => {
 387                    add_message_content_part(
 388                        open_ai::MessagePart::Image {
 389                            image_url: ImageUrl {
 390                                url: image.to_base64_url(),
 391                                detail: None,
 392                            },
 393                        },
 394                        message.role,
 395                        &mut messages,
 396                    );
 397                }
 398                MessageContent::ToolUse(tool_use) => {
 399                    let tool_call = open_ai::ToolCall {
 400                        id: tool_use.id.to_string(),
 401                        content: open_ai::ToolCallContent::Function {
 402                            function: open_ai::FunctionContent {
 403                                name: tool_use.name.to_string(),
 404                                arguments: serde_json::to_string(&tool_use.input)
 405                                    .unwrap_or_default(),
 406                            },
 407                        },
 408                    };
 409
 410                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
 411                        messages.last_mut()
 412                    {
 413                        tool_calls.push(tool_call);
 414                    } else {
 415                        messages.push(open_ai::RequestMessage::Assistant {
 416                            content: None,
 417                            tool_calls: vec![tool_call],
 418                        });
 419                    }
 420                }
 421                MessageContent::ToolResult(tool_result) => {
 422                    let content = match &tool_result.content {
 423                        LanguageModelToolResultContent::Text(text) => {
 424                            vec![open_ai::MessagePart::Text {
 425                                text: text.to_string(),
 426                            }]
 427                        }
 428                        LanguageModelToolResultContent::Image(image) => {
 429                            vec![open_ai::MessagePart::Image {
 430                                image_url: ImageUrl {
 431                                    url: image.to_base64_url(),
 432                                    detail: None,
 433                                },
 434                            }]
 435                        }
 436                    };
 437
 438                    messages.push(open_ai::RequestMessage::Tool {
 439                        content: content.into(),
 440                        tool_call_id: tool_result.tool_use_id.to_string(),
 441                    });
 442                }
 443            }
 444        }
 445    }
 446
 447    open_ai::Request {
 448        model: model_id.into(),
 449        messages,
 450        stream,
 451        stop: request.stop,
 452        temperature: request.temperature.unwrap_or(1.0),
 453        max_completion_tokens: max_output_tokens,
 454        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
 455            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
 456            Some(false)
 457        } else {
 458            None
 459        },
 460        tools: request
 461            .tools
 462            .into_iter()
 463            .map(|tool| open_ai::ToolDefinition::Function {
 464                function: open_ai::FunctionDefinition {
 465                    name: tool.name,
 466                    description: Some(tool.description),
 467                    parameters: Some(tool.input_schema),
 468                },
 469            })
 470            .collect(),
 471        tool_choice: request.tool_choice.map(|choice| match choice {
 472            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
 473            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
 474            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
 475        }),
 476    }
 477}
 478
 479fn add_message_content_part(
 480    new_part: open_ai::MessagePart,
 481    role: Role,
 482    messages: &mut Vec<open_ai::RequestMessage>,
 483) {
 484    match (role, messages.last_mut()) {
 485        (Role::User, Some(open_ai::RequestMessage::User { content }))
 486        | (
 487            Role::Assistant,
 488            Some(open_ai::RequestMessage::Assistant {
 489                content: Some(content),
 490                ..
 491            }),
 492        )
 493        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
 494            content.push_part(new_part);
 495        }
 496        _ => {
 497            messages.push(match role {
 498                Role::User => open_ai::RequestMessage::User {
 499                    content: open_ai::MessageContent::from(vec![new_part]),
 500                },
 501                Role::Assistant => open_ai::RequestMessage::Assistant {
 502                    content: Some(open_ai::MessageContent::from(vec![new_part])),
 503                    tool_calls: Vec::new(),
 504                },
 505                Role::System => open_ai::RequestMessage::System {
 506                    content: open_ai::MessageContent::from(vec![new_part]),
 507                },
 508            });
 509        }
 510    }
 511}
 512
/// Stateful translator from raw OpenAI stream events to
/// `LanguageModelCompletionEvent`s. Accumulates partial tool-call deltas
/// (keyed by their stream index) until a finish reason flushes them.
pub struct OpenAiEventMapper {
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 516
impl OpenAiEventMapper {
    /// Creates a mapper with no pending tool calls.
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Consumes the mapper and adapts a raw event stream into completion
    /// events, preserving per-stream tool-call accumulation state across
    /// events.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

    /// Maps one raw stream event to zero or more completion events.
    ///
    /// Emits usage updates and text deltas immediately; tool-call deltas are
    /// buffered in `tool_calls_by_index` and only emitted (then drained) when
    /// the choice reports a "tool_calls" finish reason.
    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                // OpenAI's stream usage does not report cache token counts here.
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        // Only the first choice is consumed; a usage-only event has none.
        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(content) = choice.delta.content.clone() {
            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
        }

        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                // Deltas for the same call share an index; merge id/name and
                // concatenate argument fragments as they arrive.
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                // Flush every accumulated call; malformed JSON arguments are
                // surfaced as a parse-error event rather than a hard failure.
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            Some(stop_reason) => {
                // Unknown finish reasons are logged but treated as a normal
                // end of turn so the stream still terminates cleanly.
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}
 616
/// Accumulator for a tool call assembled from streaming deltas: the id and
/// name arrive once, while `arguments` is concatenated fragment by fragment.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}
 623
 624pub fn count_open_ai_tokens(
 625    request: LanguageModelRequest,
 626    model: Model,
 627    cx: &App,
 628) -> BoxFuture<'static, Result<u64>> {
 629    cx.background_spawn(async move {
 630        let messages = request
 631            .messages
 632            .into_iter()
 633            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 634                role: match message.role {
 635                    Role::User => "user".into(),
 636                    Role::Assistant => "assistant".into(),
 637                    Role::System => "system".into(),
 638                },
 639                content: Some(message.string_contents()),
 640                name: None,
 641                function_call: None,
 642            })
 643            .collect::<Vec<_>>();
 644
 645        match model {
 646            Model::Custom { max_tokens, .. } => {
 647                let model = if max_tokens >= 100_000 {
 648                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
 649                    "gpt-4o"
 650                } else {
 651                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
 652                    // supported with this tiktoken method
 653                    "gpt-4"
 654                };
 655                tiktoken_rs::num_tokens_from_messages(model, &messages)
 656            }
 657            // Currently supported by tiktoken_rs
 658            // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
 659            // arm with an override. We enumerate all supported models here so that we can check if new
 660            // models are supported yet or not.
 661            Model::ThreePointFiveTurbo
 662            | Model::Four
 663            | Model::FourTurbo
 664            | Model::FourOmni
 665            | Model::FourOmniMini
 666            | Model::FourPointOne
 667            | Model::FourPointOneMini
 668            | Model::FourPointOneNano
 669            | Model::O1
 670            | Model::O3
 671            | Model::O3Mini
 672            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
 673        }
 674        .map(|tokens| tokens as u64)
 675    })
 676    .boxed()
 677}
 678
/// UI for entering/resetting the OpenAI API key and custom API URL.
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    api_url_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    // Some while stored credentials are being loaded; cleared once the
    // initial authenticate attempt finishes.
    load_credentials_task: Option<Task<()>>,
}
 685
impl ConfigurationView {
    /// Builds the view: creates the two inputs, pre-fills the URL editor from
    /// settings, observes credential state, and kicks off an initial
    /// authentication attempt.
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                // Placeholder text shaped like a real key.
                "sk-000000000000000000000000000000000000000000000000",
            )
            .label("API key")
        });

        let api_url = AllLanguageModelSettings::get_global(cx)
            .openai
            .api_url
            .clone();

        let api_url_editor = cx.new(|cx| {
            let input = SingleLineInput::new(window, cx, open_ai::OPEN_AI_API_URL).label("API URL");

            // Pre-fill with the configured URL; an empty setting leaves the
            // default-URL placeholder visible instead.
            if !api_url.is_empty() {
                input.editor.update(cx, |editor, cx| {
                    editor.set_text(&*api_url, window, cx);
                });
            }
            input
        });

        // Re-render whenever the credential state changes.
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            api_url_editor,
            state,
            load_credentials_task,
        }
    }

    /// Saves the key entered in the editor (triggered by menu::Confirm).
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// Clears the key editor and removes the stored credentials.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// Writes the URL from the editor into the settings file, but only when
    /// it is non-empty and actually differs from the effective current URL.
    fn save_api_url(&mut self, cx: &mut Context<Self>) {
        let api_url = self
            .api_url_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        let current_url = AllLanguageModelSettings::get_global(cx)
            .openai
            .api_url
            .clone();

        // An empty setting means the default OpenAI URL is in effect.
        let effective_current_url = if current_url.is_empty() {
            open_ai::OPEN_AI_API_URL
        } else {
            &current_url
        };

        if !api_url.is_empty() && api_url != effective_current_url {
            let fs = <dyn Fs>::global(cx);
            update_settings_file::<AllLanguageModelSettings>(fs, cx, move |settings, _| {
                if let Some(settings) = settings.openai.as_mut() {
                    settings.api_url = Some(api_url.clone());
                } else {
                    settings.openai = Some(OpenAiSettingsContent {
                        api_url: Some(api_url.clone()),
                        available_models: None,
                    });
                }
            });
        }
    }

    /// Clears the URL editor and removes the custom URL from settings,
    /// falling back to the default endpoint.
    fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_url_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });
        let fs = <dyn Fs>::global(cx);
        update_settings_file::<AllLanguageModelSettings>(fs, cx, |settings, _cx| {
            if let Some(settings) = settings.openai.as_mut() {
                settings.api_url = None;
            }
        });
        cx.notify();
    }

    /// The key-entry editor is only shown while unauthenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
 841
/// Renders the OpenAI provider configuration panel: an API-key section
/// (input instructions, or a "configured" banner with a reset button) and an
/// API-URL section (custom-URL banner with reset, or an input to set one).
impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // Whether the key came from the environment variable rather than the
        // credential store; changes the banner text and disables reset.
        let env_var_set = self.state.read(cx).api_key_from_env;

        let api_key_section = if self.should_render_editor(cx) {
            // Not authenticated yet: show setup instructions and the key input.
            v_flex()
                .on_action(cx.listener(Self::save_api_key))

                .child(Label::new("To use Zed's assistant with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Already authenticated: show a confirmation banner plus a reset
            // button (reset is a no-op hint when the key is env-var-sourced).
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            // The key cannot be reset from here while the env
                            // var is set; explain via tooltip instead.
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        // A URL differing from the default endpoint means the user has
        // configured a custom base URL.
        let custom_api_url_set =
            AllLanguageModelSettings::get_global(cx).openai.api_url != open_ai::OPEN_AI_API_URL;

        let api_url_section = if custom_api_url_set {
            // Custom URL active: confirmation banner plus a reset button.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new("Custom API URL configured.")),
                )
                .child(
                    Button::new("reset-api-url", "Reset API URL")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .on_click(
                            cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
                        ),
                )
                .into_any()
        } else {
            // Default URL in use: offer an input to set a custom endpoint;
            // Enter (menu::Confirm) saves it.
            v_flex()
                .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
                    this.save_api_url(cx);
                    cx.notify();
                }))
                .mt_2()
                .pt_2()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .gap_1()
                .child(
                    List::new()
                        .child(InstructionListItem::text_only(
                            "Optionally, you can change the base URL for the OpenAI API request.",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste the new API endpoint below and hit enter",
                        )),
                )
                .child(self.api_url_editor.clone())
                .into_any()
        };

        if self.load_credentials_task.is_some() {
            // Credential lookup still in flight: render a placeholder only.
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(api_url_section)
                .into_any()
        }
    }
}
 978
 979#[cfg(test)]
 980mod tests {
 981    use gpui::TestAppContext;
 982    use language_model::LanguageModelRequestMessage;
 983
 984    use super::*;
 985
 986    #[gpui::test]
 987    fn tiktoken_rs_support(cx: &TestAppContext) {
 988        let request = LanguageModelRequest {
 989            thread_id: None,
 990            prompt_id: None,
 991            intent: None,
 992            mode: None,
 993            messages: vec![LanguageModelRequestMessage {
 994                role: Role::User,
 995                content: vec![MessageContent::Text("message".into())],
 996                cache: false,
 997            }],
 998            tools: vec![],
 999            tool_choice: None,
1000            stop: vec![],
1001            temperature: None,
1002        };
1003
1004        // Validate that all models are supported by tiktoken-rs
1005        for model in Model::iter() {
1006            let count = cx
1007                .executor()
1008                .block(count_open_ai_tokens(
1009                    request.clone(),
1010                    model,
1011                    &cx.app.borrow(),
1012                ))
1013                .unwrap();
1014            assert!(count > 0);
1015        }
1016    }
1017}