open_ai.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use collections::{BTreeMap, HashMap};
   3use credentials_provider::CredentialsProvider;
   4
   5use futures::Stream;
   6use futures::{FutureExt, StreamExt, future::BoxFuture};
   7use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
   8use http_client::HttpClient;
   9use language_model::{
  10    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  11    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
  12    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
  13    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  14    RateLimiter, Role, StopReason, TokenUsage,
  15};
  16use menu;
  17use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
  18use schemars::JsonSchema;
  19use serde::{Deserialize, Serialize};
  20use settings::{Settings, SettingsStore};
  21use std::pin::Pin;
  22use std::str::FromStr as _;
  23use std::sync::Arc;
  24use strum::IntoEnumIterator;
  25
  26use ui::{ElevationIndex, List, Tooltip, prelude::*};
  27use ui_input::SingleLineInput;
  28use util::ResultExt;
  29
  30use crate::{AllLanguageModelSettings, ui::InstructionListItem};
  31
// Shared provider identifiers, re-exported from the `language_model` crate so
// every reference to the OpenAI provider agrees on the same id/name.
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
  34
/// User-facing settings for the OpenAI provider: the endpoint to call and any
/// extra models declared in the settings file.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    // Base URL requests are sent to; also used as the key when reading and
    // writing credentials (see `State`).
    pub api_url: String,
    // Models declared in settings; these extend or override the built-in list
    // in `provided_models`.
    pub available_models: Vec<AvailableModel>,
}
  40
/// A model entry as declared in user settings; mirrored into
/// `open_ai::Model::Custom` when the provider builds its model list.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    // Model identifier sent to the API.
    pub name: String,
    // Optional human-readable name for the UI; falls back to `name` elsewhere.
    pub display_name: Option<String>,
    // Context window size, in tokens.
    pub max_tokens: u64,
    pub max_output_tokens: Option<u64>,
    pub max_completion_tokens: Option<u64>,
    // Reasoning effort forwarded on requests, for models that accept it.
    pub reasoning_effort: Option<ReasoningEffort>,
}
  50
/// The OpenAI language-model provider: owns the HTTP client used for API
/// calls and the shared authentication state entity.
pub struct OpenAiLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}
  55
/// Authentication state shared by the provider, its models, and the
/// configuration view.
pub struct State {
    // API key currently in use, if any.
    api_key: Option<String>,
    // True when the key came from OPENAI_API_KEY rather than the credentials
    // store; env keys survive API-URL changes (see the settings observer).
    api_key_from_env: bool,
    // Last observed API URL. Credentials are keyed by URL, so a URL change
    // requires re-reading them.
    last_api_url: String,
    // Keeps the settings-store observation alive for the state's lifetime.
    _subscription: Subscription,
}

// Environment variable consulted before the credentials store.
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
  64
  65impl State {
  66    fn is_authenticated(&self) -> bool {
  67        self.api_key.is_some()
  68    }
  69
  70    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
  71        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  72        let api_url = AllLanguageModelSettings::get_global(cx)
  73            .openai
  74            .api_url
  75            .clone();
  76        cx.spawn(async move |this, cx| {
  77            credentials_provider
  78                .delete_credentials(&api_url, cx)
  79                .await
  80                .log_err();
  81            this.update(cx, |this, cx| {
  82                this.api_key = None;
  83                this.api_key_from_env = false;
  84                cx.notify();
  85            })
  86        })
  87    }
  88
  89    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
  90        let credentials_provider = <dyn CredentialsProvider>::global(cx);
  91        let api_url = AllLanguageModelSettings::get_global(cx)
  92            .openai
  93            .api_url
  94            .clone();
  95        cx.spawn(async move |this, cx| {
  96            credentials_provider
  97                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
  98                .await
  99                .log_err();
 100            this.update(cx, |this, cx| {
 101                this.api_key = Some(api_key);
 102                cx.notify();
 103            })
 104        })
 105    }
 106
 107    fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 108        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 109        let api_url = AllLanguageModelSettings::get_global(cx)
 110            .openai
 111            .api_url
 112            .clone();
 113        cx.spawn(async move |this, cx| {
 114            let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
 115                (api_key, true)
 116            } else {
 117                let (_, api_key) = credentials_provider
 118                    .read_credentials(&api_url, cx)
 119                    .await?
 120                    .ok_or(AuthenticateError::CredentialsNotFound)?;
 121                (
 122                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
 123                    false,
 124                )
 125            };
 126            this.update(cx, |this, cx| {
 127                this.api_key = Some(api_key);
 128                this.api_key_from_env = from_env;
 129                cx.notify();
 130            })?;
 131
 132            Ok(())
 133        })
 134    }
 135
 136    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 137        if self.is_authenticated() {
 138            return Task::ready(Ok(()));
 139        }
 140
 141        self.get_api_key(cx)
 142    }
 143}
 144
 145impl OpenAiLanguageModelProvider {
 146    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 147        let initial_api_url = AllLanguageModelSettings::get_global(cx)
 148            .openai
 149            .api_url
 150            .clone();
 151
 152        let state = cx.new(|cx| State {
 153            api_key: None,
 154            api_key_from_env: false,
 155            last_api_url: initial_api_url.clone(),
 156            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 157                let current_api_url = AllLanguageModelSettings::get_global(cx)
 158                    .openai
 159                    .api_url
 160                    .clone();
 161
 162                if this.last_api_url != current_api_url {
 163                    this.last_api_url = current_api_url;
 164                    if !this.api_key_from_env {
 165                        this.api_key = None;
 166                        let spawn_task = cx.spawn(async move |handle, cx| {
 167                            if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
 168                                if let Err(_) = task.await {
 169                                    handle
 170                                        .update(cx, |this, _| {
 171                                            this.api_key = None;
 172                                            this.api_key_from_env = false;
 173                                        })
 174                                        .ok();
 175                                }
 176                            }
 177                        });
 178                        spawn_task.detach();
 179                    }
 180                }
 181                cx.notify();
 182            }),
 183        });
 184
 185        Self { http_client, state }
 186    }
 187
 188    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
 189        Arc::new(OpenAiLanguageModel {
 190            id: LanguageModelId::from(model.id().to_string()),
 191            model,
 192            state: self.state.clone(),
 193            http_client: self.http_client.clone(),
 194            request_limiter: RateLimiter::new(4),
 195        })
 196    }
 197}
 198
impl LanguageModelProviderState for OpenAiLanguageModelProvider {
    type ObservableEntity = State;

    // Expose the auth state entity so observers (e.g. the UI) can react to
    // authentication changes.
    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}
 206
 207impl LanguageModelProvider for OpenAiLanguageModelProvider {
 208    fn id(&self) -> LanguageModelProviderId {
 209        PROVIDER_ID
 210    }
 211
 212    fn name(&self) -> LanguageModelProviderName {
 213        PROVIDER_NAME
 214    }
 215
 216    fn icon(&self) -> IconName {
 217        IconName::AiOpenAi
 218    }
 219
 220    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 221        Some(self.create_language_model(open_ai::Model::default()))
 222    }
 223
 224    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
 225        Some(self.create_language_model(open_ai::Model::default_fast()))
 226    }
 227
 228    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 229        let mut models = BTreeMap::default();
 230
 231        // Add base models from open_ai::Model::iter()
 232        for model in open_ai::Model::iter() {
 233            if !matches!(model, open_ai::Model::Custom { .. }) {
 234                models.insert(model.id().to_string(), model);
 235            }
 236        }
 237
 238        // Override with available models from settings
 239        for model in &AllLanguageModelSettings::get_global(cx)
 240            .openai
 241            .available_models
 242        {
 243            models.insert(
 244                model.name.clone(),
 245                open_ai::Model::Custom {
 246                    name: model.name.clone(),
 247                    display_name: model.display_name.clone(),
 248                    max_tokens: model.max_tokens,
 249                    max_output_tokens: model.max_output_tokens,
 250                    max_completion_tokens: model.max_completion_tokens,
 251                    reasoning_effort: model.reasoning_effort.clone(),
 252                },
 253            );
 254        }
 255
 256        models
 257            .into_values()
 258            .map(|model| self.create_language_model(model))
 259            .collect()
 260    }
 261
 262    fn is_authenticated(&self, cx: &App) -> bool {
 263        self.state.read(cx).is_authenticated()
 264    }
 265
 266    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 267        self.state.update(cx, |state, cx| state.authenticate(cx))
 268    }
 269
 270    fn configuration_view(
 271        &self,
 272        _target_agent: language_model::ConfigurationViewTargetAgent,
 273        window: &mut Window,
 274        cx: &mut App,
 275    ) -> AnyView {
 276        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
 277            .into()
 278    }
 279
 280    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 281        self.state.update(cx, |state, cx| state.reset_api_key(cx))
 282    }
 283}
 284
/// A single OpenAI model exposed through the [`LanguageModel`] interface.
pub struct OpenAiLanguageModel {
    id: LanguageModelId,
    model: open_ai::Model,
    // Shared auth state owned by the provider (source of the API key).
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Bounds concurrent requests for this model instance.
    request_limiter: RateLimiter,
}
 292
impl OpenAiLanguageModel {
    /// Issues a streaming completion request through the rate limiter.
    ///
    /// The API key and URL are read out of app state up front so the returned
    /// future is `'static`. Fails immediately if the app state entity is
    /// gone, and with `NoApiKey` (once awaited) when no key is configured.
    fn stream_completion(
        &self,
        request: open_ai::Request,
        cx: &AsyncApp,
    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
    {
        let http_client = self.http_client.clone();
        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).openai;
            (state.api_key.clone(), settings.api_url.clone())
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let Some(api_key) = api_key else {
                return Err(LanguageModelCompletionError::NoApiKey {
                    provider: PROVIDER_NAME,
                });
            };
            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
            let response = request.await?;
            Ok(response)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
}
 322
 323impl LanguageModel for OpenAiLanguageModel {
 324    fn id(&self) -> LanguageModelId {
 325        self.id.clone()
 326    }
 327
 328    fn name(&self) -> LanguageModelName {
 329        LanguageModelName::from(self.model.display_name().to_string())
 330    }
 331
 332    fn provider_id(&self) -> LanguageModelProviderId {
 333        PROVIDER_ID
 334    }
 335
 336    fn provider_name(&self) -> LanguageModelProviderName {
 337        PROVIDER_NAME
 338    }
 339
 340    fn supports_tools(&self) -> bool {
 341        true
 342    }
 343
 344    fn supports_images(&self) -> bool {
 345        use open_ai::Model;
 346        match &self.model {
 347            Model::FourOmni
 348            | Model::FourOmniMini
 349            | Model::FourPointOne
 350            | Model::FourPointOneMini
 351            | Model::FourPointOneNano
 352            | Model::Five
 353            | Model::FiveMini
 354            | Model::FiveNano
 355            | Model::O1
 356            | Model::O3
 357            | Model::O4Mini => true,
 358            Model::ThreePointFiveTurbo
 359            | Model::Four
 360            | Model::FourTurbo
 361            | Model::O3Mini
 362            | Model::Custom { .. } => false,
 363        }
 364    }
 365
 366    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 367        match choice {
 368            LanguageModelToolChoice::Auto => true,
 369            LanguageModelToolChoice::Any => true,
 370            LanguageModelToolChoice::None => true,
 371        }
 372    }
 373
 374    fn telemetry_id(&self) -> String {
 375        format!("openai/{}", self.model.id())
 376    }
 377
 378    fn max_token_count(&self) -> u64 {
 379        self.model.max_token_count()
 380    }
 381
 382    fn max_output_tokens(&self) -> Option<u64> {
 383        self.model.max_output_tokens()
 384    }
 385
 386    fn count_tokens(
 387        &self,
 388        request: LanguageModelRequest,
 389        cx: &App,
 390    ) -> BoxFuture<'static, Result<u64>> {
 391        count_open_ai_tokens(request, self.model.clone(), cx)
 392    }
 393
 394    fn stream_completion(
 395        &self,
 396        request: LanguageModelRequest,
 397        cx: &AsyncApp,
 398    ) -> BoxFuture<
 399        'static,
 400        Result<
 401            futures::stream::BoxStream<
 402                'static,
 403                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
 404            >,
 405            LanguageModelCompletionError,
 406        >,
 407    > {
 408        let request = into_open_ai(
 409            request,
 410            self.model.id(),
 411            self.model.supports_parallel_tool_calls(),
 412            self.model.supports_prompt_cache_key(),
 413            self.max_output_tokens(),
 414            self.model.reasoning_effort(),
 415        );
 416        let completions = self.stream_completion(request, cx);
 417        async move {
 418            let mapper = OpenAiEventMapper::new();
 419            Ok(mapper.map_stream(completions.await?).boxed())
 420        }
 421        .boxed()
 422    }
 423}
 424
/// Converts a [`LanguageModelRequest`] into the OpenAI chat-completions wire
/// format.
///
/// Consecutive content parts with the same role are merged into a single
/// request message (via `add_message_content_part`), tool calls are attached
/// to the preceding assistant message, and tool results become `Tool`
/// messages keyed by the originating call id.
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
    // "o1-"-prefixed models are requested non-streaming — presumably they
    // don't support streaming; confirm against current API docs.
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                // Thinking text is sent as plain text; this request schema has
                // no dedicated "thinking" part.
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    add_message_content_part(
                        open_ai::MessagePart::Text { text },
                        message.role,
                        &mut messages,
                    )
                }
                // Redacted thinking has no representable payload; dropped.
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                // Unserializable input degrades to "" rather
                                // than failing the whole request.
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    // Attach to the trailing assistant message when there is
                    // one; otherwise synthesize a content-less assistant
                    // message to carry the call.
                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
        } else {
            None
        },
        // Thread id doubles as the prompt-cache key on models that honor it.
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}
 545
 546fn add_message_content_part(
 547    new_part: open_ai::MessagePart,
 548    role: Role,
 549    messages: &mut Vec<open_ai::RequestMessage>,
 550) {
 551    match (role, messages.last_mut()) {
 552        (Role::User, Some(open_ai::RequestMessage::User { content }))
 553        | (
 554            Role::Assistant,
 555            Some(open_ai::RequestMessage::Assistant {
 556                content: Some(content),
 557                ..
 558            }),
 559        )
 560        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
 561            content.push_part(new_part);
 562        }
 563        _ => {
 564            messages.push(match role {
 565                Role::User => open_ai::RequestMessage::User {
 566                    content: open_ai::MessageContent::from(vec![new_part]),
 567                },
 568                Role::Assistant => open_ai::RequestMessage::Assistant {
 569                    content: Some(open_ai::MessageContent::from(vec![new_part])),
 570                    tool_calls: Vec::new(),
 571                },
 572                Role::System => open_ai::RequestMessage::System {
 573                    content: open_ai::MessageContent::from(vec![new_part]),
 574                },
 575            });
 576        }
 577    }
 578}
 579
/// Maps OpenAI streaming response events to [`LanguageModelCompletionEvent`]s.
pub struct OpenAiEventMapper {
    // Tool-call fragments buffered by stream index until a `tool_calls`
    // finish reason flushes them.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
 583
impl OpenAiEventMapper {
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Adapts a raw response-event stream into completion events, flattening
    /// each incoming event into zero or more output events.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

    /// Converts one stream event into completion events.
    ///
    /// Usage updates and text deltas are emitted immediately; tool-call
    /// deltas accumulate in `tool_calls_by_index` and are only emitted once
    /// the choice reports a `tool_calls` finish reason.
    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            // Cache token counts aren't reported on this path, so both are 0.
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        // Only the first choice is considered.
        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(content) = choice.delta.content.clone() {
            if !content.is_empty() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }
        }

        // Accumulate tool-call fragments: id and name arrive once, arguments
        // arrive as string pieces that are concatenated in order.
        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                // Flush all buffered tool calls. NOTE: HashMap::drain yields
                // them in arbitrary order. Arguments that fail to parse as
                // JSON surface as a parse-error event (Ok-wrapped) so the
                // stream continues.
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                            },
                        )),
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            // Unknown finish reasons are logged and treated as end-of-turn.
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}
 685
/// Partially assembled tool call, built up from streaming deltas.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    // JSON argument text, concatenated from streamed fragments.
    arguments: String,
}
 692
 693pub(crate) fn collect_tiktoken_messages(
 694    request: LanguageModelRequest,
 695) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
 696    request
 697        .messages
 698        .into_iter()
 699        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
 700            role: match message.role {
 701                Role::User => "user".into(),
 702                Role::Assistant => "assistant".into(),
 703                Role::System => "system".into(),
 704            },
 705            content: Some(message.string_contents()),
 706            name: None,
 707            function_call: None,
 708        })
 709        .collect::<Vec<_>>()
 710}
 711
/// Counts tokens for `request` with tiktoken on a background thread.
///
/// Models tiktoken doesn't know are mapped onto a tokenizer-compatible
/// stand-in (see the branch comments below).
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);

        match model {
            Model::Custom { max_tokens, .. } => {
                // Heuristic tokenizer choice for unknown custom models.
                let model = if max_tokens >= 100_000 {
                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
                    "gpt-4o"
                } else {
                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
                    // supported with this tiktoken method
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            // Currently supported by tiktoken_rs
            // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
            // arm with an override. We enumerate all supported models here so that we can check if new
            // models are supported yet or not.
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
            Model::Five | Model::FiveMini | Model::FiveNano => {
                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
            }
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}
 757
/// Settings UI for entering and resetting the OpenAI API key.
struct ConfigurationView {
    api_key_editor: Entity<SingleLineInput>,
    state: gpui::Entity<State>,
    // Some while the initial credential load is in flight; None once done.
    load_credentials_task: Option<Task<()>>,
}
 763
impl ConfigurationView {
    /// Builds the view and kicks off credential loading, so the editor is
    /// only shown when no key turns out to be available.
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        // Placeholder mimics the shape of an OpenAI key.
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
        });

        // Re-render whenever the auth state changes.
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    /// Saves the key currently in the editor; triggered by `menu::Confirm`.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// Clears the editor and deletes the stored credential.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    // The key editor is only rendered while unauthenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
 850
 851impl Render for ConfigurationView {
 852    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 853        let env_var_set = self.state.read(cx).api_key_from_env;
 854
 855        let api_key_section = if self.should_render_editor(cx) {
 856            v_flex()
 857                .on_action(cx.listener(Self::save_api_key))
 858                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
 859                .child(
 860                    List::new()
 861                        .child(InstructionListItem::new(
 862                            "Create one by visiting",
 863                            Some("OpenAI's console"),
 864                            Some("https://platform.openai.com/api-keys"),
 865                        ))
 866                        .child(InstructionListItem::text_only(
 867                            "Ensure your OpenAI account has credits",
 868                        ))
 869                        .child(InstructionListItem::text_only(
 870                            "Paste your API key below and hit enter to start using the assistant",
 871                        )),
 872                )
 873                .child(self.api_key_editor.clone())
 874                .child(
 875                    Label::new(
 876                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
 877                    )
 878                    .size(LabelSize::Small).color(Color::Muted),
 879                )
 880                .child(
 881                    Label::new(
 882                        "Note that having a subscription for another service like GitHub Copilot won't work.",
 883                    )
 884                    .size(LabelSize::Small).color(Color::Muted),
 885                )
 886                .into_any()
 887        } else {
 888            h_flex()
 889                .mt_1()
 890                .p_1()
 891                .justify_between()
 892                .rounded_md()
 893                .border_1()
 894                .border_color(cx.theme().colors().border)
 895                .bg(cx.theme().colors().background)
 896                .child(
 897                    h_flex()
 898                        .gap_1()
 899                        .child(Icon::new(IconName::Check).color(Color::Success))
 900                        .child(Label::new(if env_var_set {
 901                            format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
 902                        } else {
 903                            "API key configured.".to_string()
 904                        })),
 905                )
 906                .child(
 907                    Button::new("reset-api-key", "Reset API Key")
 908                        .label_size(LabelSize::Small)
 909                        .icon(IconName::Undo)
 910                        .icon_size(IconSize::Small)
 911                        .icon_position(IconPosition::Start)
 912                        .layer(ElevationIndex::ModalSurface)
 913                        .when(env_var_set, |this| {
 914                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
 915                        })
 916                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
 917                )
 918                .into_any()
 919        };
 920
 921        let compatible_api_section = h_flex()
 922            .mt_1p5()
 923            .gap_0p5()
 924            .flex_wrap()
 925            .when(self.should_render_editor(cx), |this| {
 926                this.pt_1p5()
 927                    .border_t_1()
 928                    .border_color(cx.theme().colors().border_variant)
 929            })
 930            .child(
 931                h_flex()
 932                    .gap_2()
 933                    .child(
 934                        Icon::new(IconName::Info)
 935                            .size(IconSize::XSmall)
 936                            .color(Color::Muted),
 937                    )
 938                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
 939            )
 940            .child(
 941                Button::new("docs", "Learn More")
 942                    .icon(IconName::ArrowUpRight)
 943                    .icon_size(IconSize::Small)
 944                    .icon_color(Color::Muted)
 945                    .on_click(move |_, _window, cx| {
 946                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
 947                    }),
 948            );
 949
 950        if self.load_credentials_task.is_some() {
 951            div().child(Label::new("Loading credentials…")).into_any()
 952        } else {
 953            v_flex()
 954                .size_full()
 955                .child(api_key_section)
 956                .child(compatible_api_section)
 957                .into_any()
 958        }
 959    }
 960}
 961
 962#[cfg(test)]
 963mod tests {
 964    use gpui::TestAppContext;
 965    use language_model::LanguageModelRequestMessage;
 966
 967    use super::*;
 968
 969    #[gpui::test]
 970    fn tiktoken_rs_support(cx: &TestAppContext) {
 971        let request = LanguageModelRequest {
 972            thread_id: None,
 973            prompt_id: None,
 974            intent: None,
 975            mode: None,
 976            messages: vec![LanguageModelRequestMessage {
 977                role: Role::User,
 978                content: vec![MessageContent::Text("message".into())],
 979                cache: false,
 980            }],
 981            tools: vec![],
 982            tool_choice: None,
 983            stop: vec![],
 984            temperature: None,
 985            thinking_allowed: true,
 986        };
 987
 988        // Validate that all models are supported by tiktoken-rs
 989        for model in Model::iter() {
 990            let count = cx
 991                .executor()
 992                .block(count_open_ai_tokens(
 993                    request.clone(),
 994                    model,
 995                    &cx.app.borrow(),
 996                ))
 997                .unwrap();
 998            assert!(count > 0);
 999        }
1000    }
1001}