open_ai.rs

  1use anyhow::{Result, anyhow};
  2use collections::{BTreeMap, HashMap};
  3use futures::Stream;
  4use futures::{FutureExt, StreamExt, future, future::BoxFuture};
  5use gpui::{AnyView, App, AsyncApp, Context, Entity, SharedString, Task, Window};
  6use http_client::HttpClient;
  7use language_model::{
  8    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
  9    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 10    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 11    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
 12    RateLimiter, Role, StopReason, TokenUsage,
 13};
 14use menu;
 15use open_ai::{
 16    ImageUrl, Model, OPEN_AI_API_URL, ReasoningEffort, ResponseStreamEvent, stream_completion,
 17};
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use settings::{Settings, SettingsStore};
 21use std::pin::Pin;
 22use std::str::FromStr as _;
 23use std::sync::{Arc, LazyLock};
 24use strum::IntoEnumIterator;
 25use ui::{ElevationIndex, List, Tooltip, prelude::*};
 26use ui_input::SingleLineInput;
 27use util::{ResultExt, truncate_and_trailoff};
 28use zed_env_vars::{EnvVar, env_var};
 29
 30use crate::{api_key::ApiKeyState, ui::InstructionListItem};
 31
 32const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
 33const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
 34
 35const API_KEY_ENV_VAR_NAME: &str = "OPENAI_API_KEY";
 36static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
 37
 38#[derive(Default, Clone, Debug, PartialEq)]
 39pub struct OpenAiSettings {
 40    pub api_url: String,
 41    pub available_models: Vec<AvailableModel>,
 42}
 43
 44#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 45pub struct AvailableModel {
 46    pub name: String,
 47    pub display_name: Option<String>,
 48    pub max_tokens: u64,
 49    pub max_output_tokens: Option<u64>,
 50    pub max_completion_tokens: Option<u64>,
 51    pub reasoning_effort: Option<ReasoningEffort>,
 52}
 53
 54pub struct OpenAiLanguageModelProvider {
 55    http_client: Arc<dyn HttpClient>,
 56    state: gpui::Entity<State>,
 57}
 58
 59pub struct State {
 60    api_key_state: ApiKeyState,
 61}
 62
 63impl State {
 64    fn is_authenticated(&self) -> bool {
 65        self.api_key_state.has_key()
 66    }
 67
 68    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
 69        let api_url = OpenAiLanguageModelProvider::api_url(cx);
 70        self.api_key_state
 71            .store(api_url, api_key, |this| &mut this.api_key_state, cx)
 72    }
 73
 74    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 75        let api_url = OpenAiLanguageModelProvider::api_url(cx);
 76        self.api_key_state.load_if_needed(
 77            api_url,
 78            &API_KEY_ENV_VAR,
 79            |this| &mut this.api_key_state,
 80            cx,
 81        )
 82    }
 83}
 84
 85impl OpenAiLanguageModelProvider {
 86    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 87        let state = cx.new(|cx| {
 88            cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
 89                let api_url = Self::api_url(cx);
 90                this.api_key_state.handle_url_change(
 91                    api_url,
 92                    &API_KEY_ENV_VAR,
 93                    |this| &mut this.api_key_state,
 94                    cx,
 95                );
 96                cx.notify();
 97            })
 98            .detach();
 99            State {
100                api_key_state: ApiKeyState::new(Self::api_url(cx)),
101            }
102        });
103
104        Self { http_client, state }
105    }
106
107    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
108        Arc::new(OpenAiLanguageModel {
109            id: LanguageModelId::from(model.id().to_string()),
110            model,
111            state: self.state.clone(),
112            http_client: self.http_client.clone(),
113            request_limiter: RateLimiter::new(4),
114        })
115    }
116
117    fn settings(cx: &App) -> &OpenAiSettings {
118        &crate::AllLanguageModelSettings::get_global(cx).openai
119    }
120
121    fn api_url(cx: &App) -> SharedString {
122        let api_url = &Self::settings(cx).api_url;
123        if api_url.is_empty() {
124            open_ai::OPEN_AI_API_URL.into()
125        } else {
126            SharedString::new(api_url.as_str())
127        }
128    }
129}
130
131impl LanguageModelProviderState for OpenAiLanguageModelProvider {
132    type ObservableEntity = State;
133
134    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
135        Some(self.state.clone())
136    }
137}
138
139impl LanguageModelProvider for OpenAiLanguageModelProvider {
140    fn id(&self) -> LanguageModelProviderId {
141        PROVIDER_ID
142    }
143
144    fn name(&self) -> LanguageModelProviderName {
145        PROVIDER_NAME
146    }
147
148    fn icon(&self) -> IconName {
149        IconName::AiOpenAi
150    }
151
152    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
153        Some(self.create_language_model(open_ai::Model::default()))
154    }
155
156    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
157        Some(self.create_language_model(open_ai::Model::default_fast()))
158    }
159
160    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
161        let mut models = BTreeMap::default();
162
163        // Add base models from open_ai::Model::iter()
164        for model in open_ai::Model::iter() {
165            if !matches!(model, open_ai::Model::Custom { .. }) {
166                models.insert(model.id().to_string(), model);
167            }
168        }
169
170        // Override with available models from settings
171        for model in &OpenAiLanguageModelProvider::settings(cx).available_models {
172            models.insert(
173                model.name.clone(),
174                open_ai::Model::Custom {
175                    name: model.name.clone(),
176                    display_name: model.display_name.clone(),
177                    max_tokens: model.max_tokens,
178                    max_output_tokens: model.max_output_tokens,
179                    max_completion_tokens: model.max_completion_tokens,
180                    reasoning_effort: model.reasoning_effort.clone(),
181                },
182            );
183        }
184
185        models
186            .into_values()
187            .map(|model| self.create_language_model(model))
188            .collect()
189    }
190
191    fn is_authenticated(&self, cx: &App) -> bool {
192        self.state.read(cx).is_authenticated()
193    }
194
195    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
196        self.state.update(cx, |state, cx| state.authenticate(cx))
197    }
198
199    fn configuration_view(
200        &self,
201        _target_agent: language_model::ConfigurationViewTargetAgent,
202        window: &mut Window,
203        cx: &mut App,
204    ) -> AnyView {
205        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
206            .into()
207    }
208
209    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
210        self.state
211            .update(cx, |state, cx| state.set_api_key(None, cx))
212    }
213}
214
215pub struct OpenAiLanguageModel {
216    id: LanguageModelId,
217    model: open_ai::Model,
218    state: gpui::Entity<State>,
219    http_client: Arc<dyn HttpClient>,
220    request_limiter: RateLimiter,
221}
222
223impl OpenAiLanguageModel {
224    fn stream_completion(
225        &self,
226        request: open_ai::Request,
227        cx: &AsyncApp,
228    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
229    {
230        let http_client = self.http_client.clone();
231
232        let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
233            let api_url = OpenAiLanguageModelProvider::api_url(cx);
234            (state.api_key_state.key(&api_url), api_url)
235        }) else {
236            return future::ready(Err(anyhow!("App state dropped"))).boxed();
237        };
238
239        let future = self.request_limiter.stream(async move {
240            let Some(api_key) = api_key else {
241                return Err(LanguageModelCompletionError::NoApiKey {
242                    provider: PROVIDER_NAME,
243                });
244            };
245            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
246            let response = request.await?;
247            Ok(response)
248        });
249
250        async move { Ok(future.await?.boxed()) }.boxed()
251    }
252}
253
254impl LanguageModel for OpenAiLanguageModel {
255    fn id(&self) -> LanguageModelId {
256        self.id.clone()
257    }
258
259    fn name(&self) -> LanguageModelName {
260        LanguageModelName::from(self.model.display_name().to_string())
261    }
262
263    fn provider_id(&self) -> LanguageModelProviderId {
264        PROVIDER_ID
265    }
266
267    fn provider_name(&self) -> LanguageModelProviderName {
268        PROVIDER_NAME
269    }
270
271    fn supports_tools(&self) -> bool {
272        true
273    }
274
275    fn supports_images(&self) -> bool {
276        use open_ai::Model;
277        match &self.model {
278            Model::FourOmni
279            | Model::FourOmniMini
280            | Model::FourPointOne
281            | Model::FourPointOneMini
282            | Model::FourPointOneNano
283            | Model::Five
284            | Model::FiveMini
285            | Model::FiveNano
286            | Model::O1
287            | Model::O3
288            | Model::O4Mini => true,
289            Model::ThreePointFiveTurbo
290            | Model::Four
291            | Model::FourTurbo
292            | Model::O3Mini
293            | Model::Custom { .. } => false,
294        }
295    }
296
297    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
298        match choice {
299            LanguageModelToolChoice::Auto => true,
300            LanguageModelToolChoice::Any => true,
301            LanguageModelToolChoice::None => true,
302        }
303    }
304
305    fn telemetry_id(&self) -> String {
306        format!("openai/{}", self.model.id())
307    }
308
309    fn max_token_count(&self) -> u64 {
310        self.model.max_token_count()
311    }
312
313    fn max_output_tokens(&self) -> Option<u64> {
314        self.model.max_output_tokens()
315    }
316
317    fn count_tokens(
318        &self,
319        request: LanguageModelRequest,
320        cx: &App,
321    ) -> BoxFuture<'static, Result<u64>> {
322        count_open_ai_tokens(request, self.model.clone(), cx)
323    }
324
325    fn stream_completion(
326        &self,
327        request: LanguageModelRequest,
328        cx: &AsyncApp,
329    ) -> BoxFuture<
330        'static,
331        Result<
332            futures::stream::BoxStream<
333                'static,
334                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
335            >,
336            LanguageModelCompletionError,
337        >,
338    > {
339        let request = into_open_ai(
340            request,
341            self.model.id(),
342            self.model.supports_parallel_tool_calls(),
343            self.model.supports_prompt_cache_key(),
344            self.max_output_tokens(),
345            self.model.reasoning_effort(),
346        );
347        let completions = self.stream_completion(request, cx);
348        async move {
349            let mapper = OpenAiEventMapper::new();
350            Ok(mapper.map_stream(completions.await?).boxed())
351        }
352        .boxed()
353    }
354}
355
356pub fn into_open_ai(
357    request: LanguageModelRequest,
358    model_id: &str,
359    supports_parallel_tool_calls: bool,
360    supports_prompt_cache_key: bool,
361    max_output_tokens: Option<u64>,
362    reasoning_effort: Option<ReasoningEffort>,
363) -> open_ai::Request {
364    let stream = !model_id.starts_with("o1-");
365
366    let mut messages = Vec::new();
367    for message in request.messages {
368        for content in message.content {
369            match content {
370                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
371                    add_message_content_part(
372                        open_ai::MessagePart::Text { text },
373                        message.role,
374                        &mut messages,
375                    )
376                }
377                MessageContent::RedactedThinking(_) => {}
378                MessageContent::Image(image) => {
379                    add_message_content_part(
380                        open_ai::MessagePart::Image {
381                            image_url: ImageUrl {
382                                url: image.to_base64_url(),
383                                detail: None,
384                            },
385                        },
386                        message.role,
387                        &mut messages,
388                    );
389                }
390                MessageContent::ToolUse(tool_use) => {
391                    let tool_call = open_ai::ToolCall {
392                        id: tool_use.id.to_string(),
393                        content: open_ai::ToolCallContent::Function {
394                            function: open_ai::FunctionContent {
395                                name: tool_use.name.to_string(),
396                                arguments: serde_json::to_string(&tool_use.input)
397                                    .unwrap_or_default(),
398                            },
399                        },
400                    };
401
402                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
403                        messages.last_mut()
404                    {
405                        tool_calls.push(tool_call);
406                    } else {
407                        messages.push(open_ai::RequestMessage::Assistant {
408                            content: None,
409                            tool_calls: vec![tool_call],
410                        });
411                    }
412                }
413                MessageContent::ToolResult(tool_result) => {
414                    let content = match &tool_result.content {
415                        LanguageModelToolResultContent::Text(text) => {
416                            vec![open_ai::MessagePart::Text {
417                                text: text.to_string(),
418                            }]
419                        }
420                        LanguageModelToolResultContent::Image(image) => {
421                            vec![open_ai::MessagePart::Image {
422                                image_url: ImageUrl {
423                                    url: image.to_base64_url(),
424                                    detail: None,
425                                },
426                            }]
427                        }
428                    };
429
430                    messages.push(open_ai::RequestMessage::Tool {
431                        content: content.into(),
432                        tool_call_id: tool_result.tool_use_id.to_string(),
433                    });
434                }
435            }
436        }
437    }
438
439    open_ai::Request {
440        model: model_id.into(),
441        messages,
442        stream,
443        stop: request.stop,
444        temperature: request.temperature.unwrap_or(1.0),
445        max_completion_tokens: max_output_tokens,
446        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
447            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
448            Some(false)
449        } else {
450            None
451        },
452        prompt_cache_key: if supports_prompt_cache_key {
453            request.thread_id
454        } else {
455            None
456        },
457        tools: request
458            .tools
459            .into_iter()
460            .map(|tool| open_ai::ToolDefinition::Function {
461                function: open_ai::FunctionDefinition {
462                    name: tool.name,
463                    description: Some(tool.description),
464                    parameters: Some(tool.input_schema),
465                },
466            })
467            .collect(),
468        tool_choice: request.tool_choice.map(|choice| match choice {
469            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
470            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
471            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
472        }),
473        reasoning_effort,
474    }
475}
476
477fn add_message_content_part(
478    new_part: open_ai::MessagePart,
479    role: Role,
480    messages: &mut Vec<open_ai::RequestMessage>,
481) {
482    match (role, messages.last_mut()) {
483        (Role::User, Some(open_ai::RequestMessage::User { content }))
484        | (
485            Role::Assistant,
486            Some(open_ai::RequestMessage::Assistant {
487                content: Some(content),
488                ..
489            }),
490        )
491        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
492            content.push_part(new_part);
493        }
494        _ => {
495            messages.push(match role {
496                Role::User => open_ai::RequestMessage::User {
497                    content: open_ai::MessageContent::from(vec![new_part]),
498                },
499                Role::Assistant => open_ai::RequestMessage::Assistant {
500                    content: Some(open_ai::MessageContent::from(vec![new_part])),
501                    tool_calls: Vec::new(),
502                },
503                Role::System => open_ai::RequestMessage::System {
504                    content: open_ai::MessageContent::from(vec![new_part]),
505                },
506            });
507        }
508    }
509}
510
511pub struct OpenAiEventMapper {
512    tool_calls_by_index: HashMap<usize, RawToolCall>,
513}
514
515impl OpenAiEventMapper {
516    pub fn new() -> Self {
517        Self {
518            tool_calls_by_index: HashMap::default(),
519        }
520    }
521
522    pub fn map_stream(
523        mut self,
524        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
525    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
526    {
527        events.flat_map(move |event| {
528            futures::stream::iter(match event {
529                Ok(event) => self.map_event(event),
530                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
531            })
532        })
533    }
534
535    pub fn map_event(
536        &mut self,
537        event: ResponseStreamEvent,
538    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
539        let mut events = Vec::new();
540        if let Some(usage) = event.usage {
541            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
542                input_tokens: usage.prompt_tokens,
543                output_tokens: usage.completion_tokens,
544                cache_creation_input_tokens: 0,
545                cache_read_input_tokens: 0,
546            })));
547        }
548
549        let Some(choice) = event.choices.first() else {
550            return events;
551        };
552
553        if let Some(content) = choice.delta.content.clone() {
554            if !content.is_empty() {
555                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
556            }
557        }
558
559        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
560            for tool_call in tool_calls {
561                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
562
563                if let Some(tool_id) = tool_call.id.clone() {
564                    entry.id = tool_id;
565                }
566
567                if let Some(function) = tool_call.function.as_ref() {
568                    if let Some(name) = function.name.clone() {
569                        entry.name = name;
570                    }
571
572                    if let Some(arguments) = function.arguments.clone() {
573                        entry.arguments.push_str(&arguments);
574                    }
575                }
576            }
577        }
578
579        match choice.finish_reason.as_deref() {
580            Some("stop") => {
581                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
582            }
583            Some("tool_calls") => {
584                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
585                    match serde_json::Value::from_str(&tool_call.arguments) {
586                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
587                            LanguageModelToolUse {
588                                id: tool_call.id.clone().into(),
589                                name: tool_call.name.as_str().into(),
590                                is_input_complete: true,
591                                input,
592                                raw_input: tool_call.arguments.clone(),
593                            },
594                        )),
595                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
596                            id: tool_call.id.into(),
597                            tool_name: tool_call.name.into(),
598                            raw_input: tool_call.arguments.clone().into(),
599                            json_parse_error: error.to_string(),
600                        }),
601                    }
602                }));
603
604                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
605            }
606            Some(stop_reason) => {
607                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
608                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
609            }
610            None => {}
611        }
612
613        events
614    }
615}
616
617#[derive(Default)]
618struct RawToolCall {
619    id: String,
620    name: String,
621    arguments: String,
622}
623
624pub(crate) fn collect_tiktoken_messages(
625    request: LanguageModelRequest,
626) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
627    request
628        .messages
629        .into_iter()
630        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
631            role: match message.role {
632                Role::User => "user".into(),
633                Role::Assistant => "assistant".into(),
634                Role::System => "system".into(),
635            },
636            content: Some(message.string_contents()),
637            name: None,
638            function_call: None,
639        })
640        .collect::<Vec<_>>()
641}
642
643pub fn count_open_ai_tokens(
644    request: LanguageModelRequest,
645    model: Model,
646    cx: &App,
647) -> BoxFuture<'static, Result<u64>> {
648    cx.background_spawn(async move {
649        let messages = collect_tiktoken_messages(request);
650
651        match model {
652            Model::Custom { max_tokens, .. } => {
653                let model = if max_tokens >= 100_000 {
654                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
655                    "gpt-4o"
656                } else {
657                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
658                    // supported with this tiktoken method
659                    "gpt-4"
660                };
661                tiktoken_rs::num_tokens_from_messages(model, &messages)
662            }
663            // Currently supported by tiktoken_rs
664            // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
665            // arm with an override. We enumerate all supported models here so that we can check if new
666            // models are supported yet or not.
667            Model::ThreePointFiveTurbo
668            | Model::Four
669            | Model::FourTurbo
670            | Model::FourOmni
671            | Model::FourOmniMini
672            | Model::FourPointOne
673            | Model::FourPointOneMini
674            | Model::FourPointOneNano
675            | Model::O1
676            | Model::O3
677            | Model::O3Mini
678            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
679            // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
680            Model::Five | Model::FiveMini | Model::FiveNano => {
681                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
682            }
683        }
684        .map(|tokens| tokens as u64)
685    })
686    .boxed()
687}
688
689struct ConfigurationView {
690    api_key_editor: Entity<SingleLineInput>,
691    state: gpui::Entity<State>,
692    load_credentials_task: Option<Task<()>>,
693}
694
695impl ConfigurationView {
696    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
697        let api_key_editor = cx.new(|cx| {
698            SingleLineInput::new(
699                window,
700                cx,
701                "sk-000000000000000000000000000000000000000000000000",
702            )
703        });
704
705        cx.observe(&state, |_, _, cx| {
706            cx.notify();
707        })
708        .detach();
709
710        let load_credentials_task = Some(cx.spawn_in(window, {
711            let state = state.clone();
712            async move |this, cx| {
713                if let Some(task) = state
714                    .update(cx, |state, cx| state.authenticate(cx))
715                    .log_err()
716                {
717                    // We don't log an error, because "not signed in" is also an error.
718                    let _ = task.await;
719                }
720                this.update(cx, |this, cx| {
721                    this.load_credentials_task = None;
722                    cx.notify();
723                })
724                .log_err();
725            }
726        }));
727
728        Self {
729            api_key_editor,
730            state,
731            load_credentials_task,
732        }
733    }
734
735    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
736        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
737        if api_key.is_empty() {
738            return;
739        }
740
741        // url changes can cause the editor to be displayed again
742        self.api_key_editor
743            .update(cx, |editor, cx| editor.set_text("", window, cx));
744
745        let state = self.state.clone();
746        cx.spawn_in(window, async move |_, cx| {
747            state
748                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))?
749                .await
750        })
751        .detach_and_log_err(cx);
752    }
753
754    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
755        self.api_key_editor
756            .update(cx, |input, cx| input.set_text("", window, cx));
757
758        let state = self.state.clone();
759        cx.spawn_in(window, async move |_, cx| {
760            state
761                .update(cx, |state, cx| state.set_api_key(None, cx))?
762                .await
763        })
764        .detach_and_log_err(cx);
765    }
766
767    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
768        !self.state.read(cx).is_authenticated()
769    }
770}
771
772impl Render for ConfigurationView {
773    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
774        let env_var_set = self.state.read(cx).api_key_state.is_from_env_var();
775
776        let api_key_section = if self.should_render_editor(cx) {
777            v_flex()
778                .on_action(cx.listener(Self::save_api_key))
779                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
780                .child(
781                    List::new()
782                        .child(InstructionListItem::new(
783                            "Create one by visiting",
784                            Some("OpenAI's console"),
785                            Some("https://platform.openai.com/api-keys"),
786                        ))
787                        .child(InstructionListItem::text_only(
788                            "Ensure your OpenAI account has credits",
789                        ))
790                        .child(InstructionListItem::text_only(
791                            "Paste your API key below and hit enter to start using the assistant",
792                        )),
793                )
794                .child(self.api_key_editor.clone())
795                .child(
796                    Label::new(format!(
797                        "You can also assign the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed."
798                    ))
799                    .size(LabelSize::Small)
800                    .color(Color::Muted),
801                )
802                .child(
803                    Label::new(
804                        "Note that having a subscription for another service like GitHub Copilot won't work.",
805                    )
806                    .size(LabelSize::Small).color(Color::Muted),
807                )
808                .into_any()
809        } else {
810            h_flex()
811                .mt_1()
812                .p_1()
813                .justify_between()
814                .rounded_md()
815                .border_1()
816                .border_color(cx.theme().colors().border)
817                .bg(cx.theme().colors().background)
818                .child(
819                    h_flex()
820                        .gap_1()
821                        .child(Icon::new(IconName::Check).color(Color::Success))
822                        .child(Label::new(if env_var_set {
823                            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable")
824                        } else {
825                            let api_url = OpenAiLanguageModelProvider::api_url(cx);
826                            if api_url == OPEN_AI_API_URL {
827                                "API key configured".to_string()
828                            } else {
829                                format!("API key configured for {}", truncate_and_trailoff(&api_url, 32))
830                            }
831                        })),
832                )
833                .child(
834                    Button::new("reset-api-key", "Reset API Key")
835                        .label_size(LabelSize::Small)
836                        .icon(IconName::Undo)
837                        .icon_size(IconSize::Small)
838                        .icon_position(IconPosition::Start)
839                        .layer(ElevationIndex::ModalSurface)
840                        .when(env_var_set, |this| {
841                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable.")))
842                        })
843                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
844                )
845                .into_any()
846        };
847
848        let compatible_api_section = h_flex()
849            .mt_1p5()
850            .gap_0p5()
851            .flex_wrap()
852            .when(self.should_render_editor(cx), |this| {
853                this.pt_1p5()
854                    .border_t_1()
855                    .border_color(cx.theme().colors().border_variant)
856            })
857            .child(
858                h_flex()
859                    .gap_2()
860                    .child(
861                        Icon::new(IconName::Info)
862                            .size(IconSize::XSmall)
863                            .color(Color::Muted),
864                    )
865                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
866            )
867            .child(
868                Button::new("docs", "Learn More")
869                    .icon(IconName::ArrowUpRight)
870                    .icon_size(IconSize::Small)
871                    .icon_color(Color::Muted)
872                    .on_click(move |_, _window, cx| {
873                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
874                    }),
875            );
876
877        if self.load_credentials_task.is_some() {
878            div().child(Label::new("Loading credentials…")).into_any()
879        } else {
880            v_flex()
881                .size_full()
882                .child(api_key_section)
883                .child(compatible_api_section)
884                .into_any()
885        }
886    }
887}
888
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    /// Counting tokens must succeed, with a non-zero count, for every model variant.
    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text("message".into())],
            cache: false,
        };
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![user_message],
            tools: vec![],
            tool_choice: None,
            stop: vec![],
            temperature: None,
            thinking_allowed: true,
        };

        // Validate that all models are supported by tiktoken-rs
        for model in Model::iter() {
            let token_count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(token_count > 0);
        }
    }
}