//! open_ai.rs — OpenAI language-model provider integration.

  1use anyhow::{Context as _, Result, anyhow};
  2use collections::{BTreeMap, HashMap};
  3use credentials_provider::CredentialsProvider;
  4
  5use futures::Stream;
  6use futures::{FutureExt, StreamExt, future::BoxFuture};
  7use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
  8use http_client::HttpClient;
  9use language_model::{
 10    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
 11    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
 12    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
 13    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
 14    RateLimiter, Role, StopReason, TokenUsage,
 15};
 16use menu;
 17use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
 18use settings::OpenAiAvailableModel as AvailableModel;
 19use settings::{Settings, SettingsStore};
 20use std::pin::Pin;
 21use std::str::FromStr as _;
 22use std::sync::Arc;
 23use strum::IntoEnumIterator;
 24
 25use ui::{ElevationIndex, List, Tooltip, prelude::*};
 26use ui_input::SingleLineInput;
 27use util::ResultExt;
 28
 29use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 30
// Stable provider identifier, shared with the `language_model` crate so all
// references to the OpenAI provider agree on one id.
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
// Human-readable provider name, likewise sourced from `language_model`.
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
 33
/// User-configurable settings for the OpenAI provider.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL requests are sent to (also used as the credentials-store key).
    pub api_url: String,
    /// Extra models configured by the user; these override built-in models
    /// with the same name (see `provided_models`).
    pub available_models: Vec<AvailableModel>,
}
 39
/// Language-model provider backed by the OpenAI API.
pub struct OpenAiLanguageModelProvider {
    // Shared HTTP client handed to every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    // Authentication state, observable by the configuration UI.
    state: gpui::Entity<State>,
}
 44
/// Authentication state for the OpenAI provider, held in a gpui entity so
/// views can observe changes.
pub struct State {
    // API key currently in use, if any (`None` = not authenticated).
    api_key: Option<String>,
    // True when the key came from the OPENAI_API_KEY environment variable
    // rather than the credentials store.
    api_key_from_env: bool,
    // Last API URL seen in settings; used to detect URL changes so stored
    // credentials can be re-read (see `OpenAiLanguageModelProvider::new`).
    last_api_url: String,
    // Keeps the settings-store observation alive for the entity's lifetime.
    _subscription: Subscription,
}

/// Environment variable consulted before falling back to the credentials store.
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
 53
 54impl State {
 55    fn is_authenticated(&self) -> bool {
 56        self.api_key.is_some()
 57    }
 58
 59    fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
 60        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 61        let api_url = AllLanguageModelSettings::get_global(cx)
 62            .openai
 63            .api_url
 64            .clone();
 65        cx.spawn(async move |this, cx| {
 66            credentials_provider
 67                .delete_credentials(&api_url, cx)
 68                .await
 69                .log_err();
 70            this.update(cx, |this, cx| {
 71                this.api_key = None;
 72                this.api_key_from_env = false;
 73                cx.notify();
 74            })
 75        })
 76    }
 77
 78    fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
 79        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 80        let api_url = AllLanguageModelSettings::get_global(cx)
 81            .openai
 82            .api_url
 83            .clone();
 84        cx.spawn(async move |this, cx| {
 85            credentials_provider
 86                .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx)
 87                .await
 88                .log_err();
 89            this.update(cx, |this, cx| {
 90                this.api_key = Some(api_key);
 91                cx.notify();
 92            })
 93        })
 94    }
 95
 96    fn get_api_key(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 97        let credentials_provider = <dyn CredentialsProvider>::global(cx);
 98        let api_url = AllLanguageModelSettings::get_global(cx)
 99            .openai
100            .api_url
101            .clone();
102        cx.spawn(async move |this, cx| {
103            let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
104                (api_key, true)
105            } else {
106                let (_, api_key) = credentials_provider
107                    .read_credentials(&api_url, cx)
108                    .await?
109                    .ok_or(AuthenticateError::CredentialsNotFound)?;
110                (
111                    String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
112                    false,
113                )
114            };
115            this.update(cx, |this, cx| {
116                this.api_key = Some(api_key);
117                this.api_key_from_env = from_env;
118                cx.notify();
119            })?;
120
121            Ok(())
122        })
123    }
124
125    fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
126        if self.is_authenticated() {
127            return Task::ready(Ok(()));
128        }
129
130        self.get_api_key(cx)
131    }
132}
133
134impl OpenAiLanguageModelProvider {
135    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
136        let initial_api_url = AllLanguageModelSettings::get_global(cx)
137            .openai
138            .api_url
139            .clone();
140
141        let state = cx.new(|cx| State {
142            api_key: None,
143            api_key_from_env: false,
144            last_api_url: initial_api_url.clone(),
145            _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
146                let current_api_url = AllLanguageModelSettings::get_global(cx)
147                    .openai
148                    .api_url
149                    .clone();
150
151                if this.last_api_url != current_api_url {
152                    this.last_api_url = current_api_url;
153                    if !this.api_key_from_env {
154                        this.api_key = None;
155                        let spawn_task = cx.spawn(async move |handle, cx| {
156                            if let Ok(task) = handle.update(cx, |this, cx| this.get_api_key(cx)) {
157                                if let Err(_) = task.await {
158                                    handle
159                                        .update(cx, |this, _| {
160                                            this.api_key = None;
161                                            this.api_key_from_env = false;
162                                        })
163                                        .ok();
164                                }
165                            }
166                        });
167                        spawn_task.detach();
168                    }
169                }
170                cx.notify();
171            }),
172        });
173
174        Self { http_client, state }
175    }
176
177    fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
178        Arc::new(OpenAiLanguageModel {
179            id: LanguageModelId::from(model.id().to_string()),
180            model,
181            state: self.state.clone(),
182            http_client: self.http_client.clone(),
183            request_limiter: RateLimiter::new(4),
184        })
185    }
186}
187
188impl LanguageModelProviderState for OpenAiLanguageModelProvider {
189    type ObservableEntity = State;
190
191    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
192        Some(self.state.clone())
193    }
194}
195
196impl LanguageModelProvider for OpenAiLanguageModelProvider {
197    fn id(&self) -> LanguageModelProviderId {
198        PROVIDER_ID
199    }
200
201    fn name(&self) -> LanguageModelProviderName {
202        PROVIDER_NAME
203    }
204
205    fn icon(&self) -> IconName {
206        IconName::AiOpenAi
207    }
208
209    fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
210        Some(self.create_language_model(open_ai::Model::default()))
211    }
212
213    fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
214        Some(self.create_language_model(open_ai::Model::default_fast()))
215    }
216
217    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
218        let mut models = BTreeMap::default();
219
220        // Add base models from open_ai::Model::iter()
221        for model in open_ai::Model::iter() {
222            if !matches!(model, open_ai::Model::Custom { .. }) {
223                models.insert(model.id().to_string(), model);
224            }
225        }
226
227        // Override with available models from settings
228        for model in &AllLanguageModelSettings::get_global(cx)
229            .openai
230            .available_models
231        {
232            models.insert(
233                model.name.clone(),
234                open_ai::Model::Custom {
235                    name: model.name.clone(),
236                    display_name: model.display_name.clone(),
237                    max_tokens: model.max_tokens,
238                    max_output_tokens: model.max_output_tokens,
239                    max_completion_tokens: model.max_completion_tokens,
240                    reasoning_effort: model.reasoning_effort.clone(),
241                },
242            );
243        }
244
245        models
246            .into_values()
247            .map(|model| self.create_language_model(model))
248            .collect()
249    }
250
251    fn is_authenticated(&self, cx: &App) -> bool {
252        self.state.read(cx).is_authenticated()
253    }
254
255    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
256        self.state.update(cx, |state, cx| state.authenticate(cx))
257    }
258
259    fn configuration_view(
260        &self,
261        _target_agent: language_model::ConfigurationViewTargetAgent,
262        window: &mut Window,
263        cx: &mut App,
264    ) -> AnyView {
265        cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
266            .into()
267    }
268
269    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
270        self.state.update(cx, |state, cx| state.reset_api_key(cx))
271    }
272}
273
/// A single OpenAI model exposed through the `LanguageModel` trait.
pub struct OpenAiLanguageModel {
    // Identifier derived from the model id (see `create_language_model`).
    id: LanguageModelId,
    // The concrete OpenAI model (built-in variant or user-defined custom).
    model: open_ai::Model,
    // Shared provider auth state; read for the API key on each request.
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    // Per-model limiter bounding concurrent requests (4; see provider).
    request_limiter: RateLimiter,
}
281
282impl OpenAiLanguageModel {
283    fn stream_completion(
284        &self,
285        request: open_ai::Request,
286        cx: &AsyncApp,
287    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
288    {
289        let http_client = self.http_client.clone();
290        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
291            let settings = &AllLanguageModelSettings::get_global(cx).openai;
292            (state.api_key.clone(), settings.api_url.clone())
293        }) else {
294            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
295        };
296
297        let future = self.request_limiter.stream(async move {
298            let Some(api_key) = api_key else {
299                return Err(LanguageModelCompletionError::NoApiKey {
300                    provider: PROVIDER_NAME,
301                });
302            };
303            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
304            let response = request.await?;
305            Ok(response)
306        });
307
308        async move { Ok(future.await?.boxed()) }.boxed()
309    }
310}
311
312impl LanguageModel for OpenAiLanguageModel {
313    fn id(&self) -> LanguageModelId {
314        self.id.clone()
315    }
316
317    fn name(&self) -> LanguageModelName {
318        LanguageModelName::from(self.model.display_name().to_string())
319    }
320
321    fn provider_id(&self) -> LanguageModelProviderId {
322        PROVIDER_ID
323    }
324
325    fn provider_name(&self) -> LanguageModelProviderName {
326        PROVIDER_NAME
327    }
328
329    fn supports_tools(&self) -> bool {
330        true
331    }
332
333    fn supports_images(&self) -> bool {
334        use open_ai::Model;
335        match &self.model {
336            Model::FourOmni
337            | Model::FourOmniMini
338            | Model::FourPointOne
339            | Model::FourPointOneMini
340            | Model::FourPointOneNano
341            | Model::Five
342            | Model::FiveMini
343            | Model::FiveNano
344            | Model::O1
345            | Model::O3
346            | Model::O4Mini => true,
347            Model::ThreePointFiveTurbo
348            | Model::Four
349            | Model::FourTurbo
350            | Model::O3Mini
351            | Model::Custom { .. } => false,
352        }
353    }
354
355    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
356        match choice {
357            LanguageModelToolChoice::Auto => true,
358            LanguageModelToolChoice::Any => true,
359            LanguageModelToolChoice::None => true,
360        }
361    }
362
363    fn telemetry_id(&self) -> String {
364        format!("openai/{}", self.model.id())
365    }
366
367    fn max_token_count(&self) -> u64 {
368        self.model.max_token_count()
369    }
370
371    fn max_output_tokens(&self) -> Option<u64> {
372        self.model.max_output_tokens()
373    }
374
375    fn count_tokens(
376        &self,
377        request: LanguageModelRequest,
378        cx: &App,
379    ) -> BoxFuture<'static, Result<u64>> {
380        count_open_ai_tokens(request, self.model.clone(), cx)
381    }
382
383    fn stream_completion(
384        &self,
385        request: LanguageModelRequest,
386        cx: &AsyncApp,
387    ) -> BoxFuture<
388        'static,
389        Result<
390            futures::stream::BoxStream<
391                'static,
392                Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
393            >,
394            LanguageModelCompletionError,
395        >,
396    > {
397        let request = into_open_ai(
398            request,
399            self.model.id(),
400            self.model.supports_parallel_tool_calls(),
401            self.model.supports_prompt_cache_key(),
402            self.max_output_tokens(),
403            self.model.reasoning_effort(),
404        );
405        let completions = self.stream_completion(request, cx);
406        async move {
407            let mapper = OpenAiEventMapper::new();
408            Ok(mapper.map_stream(completions.await?).boxed())
409        }
410        .boxed()
411    }
412}
413
/// Converts a `LanguageModelRequest` into OpenAI's chat-completions request
/// format.
///
/// Consecutive content parts for the same role are merged into one message
/// (via `add_message_content_part`), tool calls are attached to the trailing
/// assistant message, and tool results become dedicated `Tool` messages.
pub fn into_open_ai(
    request: LanguageModelRequest,
    model_id: &str,
    supports_parallel_tool_calls: bool,
    supports_prompt_cache_key: bool,
    max_output_tokens: Option<u64>,
    reasoning_effort: Option<ReasoningEffort>,
) -> open_ai::Request {
    // Models whose id starts with "o1-" are requested without streaming.
    let stream = !model_id.starts_with("o1-");

    let mut messages = Vec::new();
    for message in request.messages {
        for content in message.content {
            match content {
                // Thinking text is sent as plain text alongside regular text.
                MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                    add_message_content_part(
                        open_ai::MessagePart::Text { text },
                        message.role,
                        &mut messages,
                    )
                }
                // Redacted thinking content is omitted from the outgoing request.
                MessageContent::RedactedThinking(_) => {}
                MessageContent::Image(image) => {
                    add_message_content_part(
                        open_ai::MessagePart::Image {
                            image_url: ImageUrl {
                                url: image.to_base64_url(),
                                detail: None,
                            },
                        },
                        message.role,
                        &mut messages,
                    );
                }
                MessageContent::ToolUse(tool_use) => {
                    let tool_call = open_ai::ToolCall {
                        id: tool_use.id.to_string(),
                        content: open_ai::ToolCallContent::Function {
                            function: open_ai::FunctionContent {
                                name: tool_use.name.to_string(),
                                // Serialization failure degrades to "" rather
                                // than aborting the whole request.
                                arguments: serde_json::to_string(&tool_use.input)
                                    .unwrap_or_default(),
                            },
                        },
                    };

                    // Append to the trailing assistant message when there is
                    // one; otherwise open a new assistant message holding only
                    // this tool call.
                    if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
                        messages.last_mut()
                    {
                        tool_calls.push(tool_call);
                    } else {
                        messages.push(open_ai::RequestMessage::Assistant {
                            content: None,
                            tool_calls: vec![tool_call],
                        });
                    }
                }
                MessageContent::ToolResult(tool_result) => {
                    let content = match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            vec![open_ai::MessagePart::Text {
                                text: text.to_string(),
                            }]
                        }
                        LanguageModelToolResultContent::Image(image) => {
                            vec![open_ai::MessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                    detail: None,
                                },
                            }]
                        }
                    };

                    messages.push(open_ai::RequestMessage::Tool {
                        content: content.into(),
                        tool_call_id: tool_result.tool_use_id.to_string(),
                    });
                }
            }
        }
    }

    open_ai::Request {
        model: model_id.into(),
        messages,
        stream,
        stop: request.stop,
        // OpenAI's documented default temperature when unspecified.
        temperature: request.temperature.unwrap_or(1.0),
        max_completion_tokens: max_output_tokens,
        parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
            // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
            Some(false)
        } else {
            None
        },
        prompt_cache_key: if supports_prompt_cache_key {
            request.thread_id
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| open_ai::ToolDefinition::Function {
                function: open_ai::FunctionDefinition {
                    name: tool.name,
                    description: Some(tool.description),
                    parameters: Some(tool.input_schema),
                },
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
            LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
            LanguageModelToolChoice::None => open_ai::ToolChoice::None,
        }),
        reasoning_effort,
    }
}
534
535fn add_message_content_part(
536    new_part: open_ai::MessagePart,
537    role: Role,
538    messages: &mut Vec<open_ai::RequestMessage>,
539) {
540    match (role, messages.last_mut()) {
541        (Role::User, Some(open_ai::RequestMessage::User { content }))
542        | (
543            Role::Assistant,
544            Some(open_ai::RequestMessage::Assistant {
545                content: Some(content),
546                ..
547            }),
548        )
549        | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
550            content.push_part(new_part);
551        }
552        _ => {
553            messages.push(match role {
554                Role::User => open_ai::RequestMessage::User {
555                    content: open_ai::MessageContent::from(vec![new_part]),
556                },
557                Role::Assistant => open_ai::RequestMessage::Assistant {
558                    content: Some(open_ai::MessageContent::from(vec![new_part])),
559                    tool_calls: Vec::new(),
560                },
561                Role::System => open_ai::RequestMessage::System {
562                    content: open_ai::MessageContent::from(vec![new_part]),
563                },
564            });
565        }
566    }
567}
568
/// Translates raw OpenAI response-stream events into
/// `LanguageModelCompletionEvent`s, accumulating partial tool-call deltas
/// until a finish reason arrives.
pub struct OpenAiEventMapper {
    // Partial tool calls keyed by the stream's tool-call index; flushed when
    // finish_reason == "tool_calls".
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
572
impl OpenAiEventMapper {
    /// Creates a mapper with no pending tool calls.
    pub fn new() -> Self {
        Self {
            tool_calls_by_index: HashMap::default(),
        }
    }

    /// Consumes the mapper and adapts an event stream: each incoming event may
    /// fan out into zero or more completion events; stream-level errors are
    /// forwarded as a single error event.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
            })
        })
    }

    /// Maps one raw stream event to completion events: usage updates, text
    /// deltas, accumulated tool-call deltas, and stop events.
    pub fn map_event(
        &mut self,
        event: ResponseStreamEvent,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        let mut events = Vec::new();
        if let Some(usage) = event.usage {
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                input_tokens: usage.prompt_tokens,
                output_tokens: usage.completion_tokens,
                // OpenAI usage here doesn't report cache token splits.
                cache_creation_input_tokens: 0,
                cache_read_input_tokens: 0,
            })));
        }

        // Only the first choice is considered (one completion per request).
        let Some(choice) = event.choices.first() else {
            return events;
        };

        if let Some(content) = choice.delta.content.clone() {
            if !content.is_empty() {
                events.push(Ok(LanguageModelCompletionEvent::Text(content)));
            }
        }

        // Tool-call deltas arrive in fragments; accumulate id/name/arguments
        // per index until the finish reason tells us they're complete.
        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
            for tool_call in tool_calls {
                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();

                if let Some(tool_id) = tool_call.id.clone() {
                    entry.id = tool_id;
                }

                if let Some(function) = tool_call.function.as_ref() {
                    if let Some(name) = function.name.clone() {
                        entry.name = name;
                    }

                    if let Some(arguments) = function.arguments.clone() {
                        entry.arguments.push_str(&arguments);
                    }
                }
            }
        }

        match choice.finish_reason.as_deref() {
            Some("stop") => {
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            Some("tool_calls") => {
                // Flush all accumulated tool calls. NOTE: HashMap drain order
                // is unspecified, so with multiple tool calls the emission
                // order may vary between runs.
                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
                    match serde_json::Value::from_str(&tool_call.arguments) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_call.id.clone().into(),
                                name: tool_call.name.as_str().into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_call.arguments.clone(),
                            },
                        )),
                        // Malformed JSON is surfaced as a parse-error event,
                        // not a stream failure.
                        Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                            id: tool_call.id.into(),
                            tool_name: tool_call.name.into(),
                            raw_input: tool_call.arguments.clone().into(),
                            json_parse_error: error.to_string(),
                        }),
                    }
                }));

                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
            }
            // Unknown finish reasons are logged and treated as an end of turn.
            Some(stop_reason) => {
                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
            }
            None => {}
        }

        events
    }
}
674
/// Accumulator for a tool call assembled from streamed deltas; `arguments`
/// grows by concatenating JSON fragments until the call is complete.
#[derive(Default)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}
681
682pub(crate) fn collect_tiktoken_messages(
683    request: LanguageModelRequest,
684) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
685    request
686        .messages
687        .into_iter()
688        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
689            role: match message.role {
690                Role::User => "user".into(),
691                Role::Assistant => "assistant".into(),
692                Role::System => "system".into(),
693            },
694            content: Some(message.string_contents()),
695            name: None,
696            function_call: None,
697        })
698        .collect::<Vec<_>>()
699}
700
/// Counts the tokens `request` would consume for `model`, running tiktoken on
/// a background thread. For models tiktoken-rs doesn't know, falls back to a
/// tokenizer-compatible stand-in model id.
pub fn count_open_ai_tokens(
    request: LanguageModelRequest,
    model: Model,
    cx: &App,
) -> BoxFuture<'static, Result<u64>> {
    cx.background_spawn(async move {
        let messages = collect_tiktoken_messages(request);

        match model {
            Model::Custom { max_tokens, .. } => {
                // Pick a stand-in tokenizer for custom models based on their
                // context size.
                let model = if max_tokens >= 100_000 {
                    // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
                    "gpt-4o"
                } else {
                    // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
                    // supported with this tiktoken method
                    "gpt-4"
                };
                tiktoken_rs::num_tokens_from_messages(model, &messages)
            }
            // Currently supported by tiktoken_rs
            // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
            // arm with an override. We enumerate all supported models here so that we can check if new
            // models are supported yet or not.
            Model::ThreePointFiveTurbo
            | Model::Four
            | Model::FourTurbo
            | Model::FourOmni
            | Model::FourOmniMini
            | Model::FourPointOne
            | Model::FourPointOneMini
            | Model::FourPointOneNano
            | Model::O1
            | Model::O3
            | Model::O3Mini
            | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
            // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
            Model::Five | Model::FiveMini | Model::FiveNano => {
                tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
            }
        }
        .map(|tokens| tokens as u64)
    })
    .boxed()
}
746
/// Settings-panel view for entering/resetting the OpenAI API key.
struct ConfigurationView {
    // Single-line text input holding the (unsaved) API key.
    api_key_editor: Entity<SingleLineInput>,
    // Shared provider auth state this view reads and mutates.
    state: gpui::Entity<State>,
    // Some while initial credential loading is in flight; None once finished
    // (used to drive a loading indicator — NOTE(review): rendering not fully
    // visible in this chunk).
    load_credentials_task: Option<Task<()>>,
}
752
impl ConfigurationView {
    /// Builds the view, starts observing auth state, and kicks off an initial
    /// authentication attempt so stored credentials are loaded.
    fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let api_key_editor = cx.new(|cx| {
            SingleLineInput::new(
                window,
                cx,
                "sk-000000000000000000000000000000000000000000000000",
            )
        });

        // Re-render whenever the auth state changes.
        cx.observe(&state, |_, _, cx| {
            cx.notify();
        })
        .detach();

        let load_credentials_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    // We don't log an error, because "not signed in" is also an error.
                    let _ = task.await;
                }
                // Clear the in-flight marker so the UI stops showing a
                // loading state.
                this.update(cx, |this, cx| {
                    this.load_credentials_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            api_key_editor,
            state,
            load_credentials_task,
        }
    }

    /// Saves the key typed into the editor (triggered by menu::Confirm).
    /// Whitespace is trimmed; an empty key is ignored unless already
    /// authenticated.
    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
        let api_key = self
            .api_key_editor
            .read(cx)
            .editor()
            .read(cx)
            .text(cx)
            .trim()
            .to_string();

        // Don't proceed if no API key is provided and we're not authenticated
        if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
            return;
        }

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state
                .update(cx, |state, cx| state.set_api_key(api_key, cx))?
                .await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// Clears the editor and deletes the stored credential.
    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
        self.api_key_editor.update(cx, |input, cx| {
            input.editor.update(cx, |editor, cx| {
                editor.set_text("", window, cx);
            });
        });

        let state = self.state.clone();
        cx.spawn_in(window, async move |_, cx| {
            state.update(cx, |state, cx| state.reset_api_key(cx))?.await
        })
        .detach_and_log_err(cx);

        cx.notify();
    }

    /// The key-entry editor is shown only while unauthenticated.
    fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
        !self.state.read(cx).is_authenticated()
    }
}
839
impl Render for ConfigurationView {
    /// Renders the OpenAI provider configuration panel: either API-key setup
    /// instructions with an input field (unauthenticated) or a "key
    /// configured" row with a reset button (authenticated), followed by a
    /// note about OpenAI-compatible providers.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // True when the key was picked up from the OPENAI_API_KEY environment
        // variable rather than entered through the UI.
        let env_var_set = self.state.read(cx).api_key_from_env;

        let api_key_section = if self.should_render_editor(cx) {
            // Unauthenticated: show setup steps and the key input. Submitting
            // the input triggers `save_api_key` via the action listener.
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Authenticated: confirmation row. The message distinguishes a
            // key from the environment from one stored via the UI.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    // An env-var key can't be reset from the UI, so the button
                    // gets an explanatory tooltip in that case (the click
                    // handler is still attached either way).
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        // Footer pointing at the OpenAI-compatible providers docs; gets a top
        // border only when shown beneath the setup instructions.
        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .icon(IconName::ArrowUpRight)
                    .icon_size(IconSize::Small)
                    .icon_color(Color::Muted)
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        // While credentials are still being loaded, render a placeholder
        // instead of either section.
        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}
950
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    /// Every `Model` variant must produce a positive token count through
    /// `count_open_ai_tokens`; this guards against model names that
    /// tiktoken-rs has no tokenizer mapping for.
    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text("message".into())],
            cache: false,
        };
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![user_message],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: None,
            thinking_allowed: true,
        };

        // Walk every known model variant and block on its token count.
        for model in Model::iter() {
            let token_count = cx
                .executor()
                .block(count_open_ai_tokens(
                    request.clone(),
                    model,
                    &cx.app.borrow(),
                ))
                .unwrap();
            assert!(token_count > 0);
        }
    }
}