1use anyhow::{Context as _, Result, anyhow};
2use collections::{BTreeMap, HashMap};
3use credentials_provider::CredentialsProvider;
4
5use futures::Stream;
6use futures::{FutureExt, StreamExt, future::BoxFuture};
7use gpui::{AnyView, App, AsyncApp, Context, Entity, Subscription, Task, Window};
8use http_client::HttpClient;
9use language_model::{
10 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
11 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
12 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
13 LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
14 RateLimiter, Role, StopReason, TokenUsage,
15};
16use menu;
17use open_ai::{ImageUrl, Model, ReasoningEffort, ResponseStreamEvent, stream_completion};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::pin::Pin;
22use std::str::FromStr as _;
23use std::sync::Arc;
24use strum::IntoEnumIterator;
25
26use ui::{ElevationIndex, List, Tooltip, prelude::*};
27use ui_input::SingleLineInput;
28use util::ResultExt;
29
30use crate::{AllLanguageModelSettings, ui::InstructionListItem};
31
/// Identifier under which this provider is registered (shared constant from `language_model`).
const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID;
/// Human-readable provider name; also embedded in errors such as `NoApiKey`.
const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME;
34
/// OpenAI provider settings, read via `AllLanguageModelSettings::get_global(cx).openai`.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
    /// Base URL of the OpenAI API. It is also passed as the URL argument to the credentials
    /// provider when reading, writing, or deleting the stored API key.
    pub api_url: String,
    /// User-declared models; each becomes an `open_ai::Model::Custom`. An entry whose `name`
    /// equals a built-in model id replaces that built-in in `provided_models`.
    pub available_models: Vec<AvailableModel>,
}

/// A custom model entry from user settings.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// Model identifier; used as the model id and as the override key.
    pub name: String,
    /// Optional display name (presumably shown in the UI instead of `name`).
    pub display_name: Option<String>,
    /// Maximum token count for the model. It also selects the tokenizer used for token
    /// counting (>= 100k uses the gpt-4o tokenizer, otherwise gpt-4).
    pub max_tokens: u64,
    /// Forwarded into `open_ai::Model::Custom::max_output_tokens`.
    pub max_output_tokens: Option<u64>,
    /// Forwarded into `open_ai::Model::Custom::max_completion_tokens`.
    pub max_completion_tokens: Option<u64>,
    /// Forwarded into `open_ai::Model::Custom::reasoning_effort`.
    pub reasoning_effort: Option<ReasoningEffort>,
}
50
/// Language model provider backed by the OpenAI API.
pub struct OpenAiLanguageModelProvider {
    /// HTTP client shared with every model this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Authentication state shared with every model and with the configuration view.
    state: gpui::Entity<State>,
}

/// Authentication state for the OpenAI provider.
pub struct State {
    /// The API key in use, if one has been loaded or entered.
    api_key: Option<String>,
    /// `true` when `api_key` was read from the `OPENAI_API_KEY` environment variable.
    api_key_from_env: bool,
    /// Keeps the `SettingsStore` observer alive; it notifies observers on settings changes.
    _subscription: Subscription,
}

/// Environment variable consulted for an API key before the credentials store.
const OPENAI_API_KEY_VAR: &str = "OPENAI_API_KEY";
63
64impl State {
65 //
66 fn is_authenticated(&self) -> bool {
67 self.api_key.is_some()
68 }
69
70 fn reset_api_key(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
71 let credentials_provider = <dyn CredentialsProvider>::global(cx);
72 let api_url = AllLanguageModelSettings::get_global(cx)
73 .openai
74 .api_url
75 .clone();
76 cx.spawn(async move |this, cx| {
77 credentials_provider
78 .delete_credentials(&api_url, &cx)
79 .await
80 .log_err();
81 this.update(cx, |this, cx| {
82 this.api_key = None;
83 this.api_key_from_env = false;
84 cx.notify();
85 })
86 })
87 }
88
89 fn set_api_key(&mut self, api_key: String, cx: &mut Context<Self>) -> Task<Result<()>> {
90 let credentials_provider = <dyn CredentialsProvider>::global(cx);
91 let api_url = AllLanguageModelSettings::get_global(cx)
92 .openai
93 .api_url
94 .clone();
95 cx.spawn(async move |this, cx| {
96 credentials_provider
97 .write_credentials(&api_url, "Bearer", api_key.as_bytes(), &cx)
98 .await
99 .log_err();
100 this.update(cx, |this, cx| {
101 this.api_key = Some(api_key);
102 cx.notify();
103 })
104 })
105 }
106
107 fn authenticate(&self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
108 if self.is_authenticated() {
109 return Task::ready(Ok(()));
110 }
111
112 let credentials_provider = <dyn CredentialsProvider>::global(cx);
113 let api_url = AllLanguageModelSettings::get_global(cx)
114 .openai
115 .api_url
116 .clone();
117 cx.spawn(async move |this, cx| {
118 let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENAI_API_KEY_VAR) {
119 (api_key, true)
120 } else {
121 let (_, api_key) = credentials_provider
122 .read_credentials(&api_url, &cx)
123 .await?
124 .ok_or(AuthenticateError::CredentialsNotFound)?;
125 (
126 String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
127 false,
128 )
129 };
130 this.update(cx, |this, cx| {
131 this.api_key = Some(api_key);
132 this.api_key_from_env = from_env;
133 cx.notify();
134 })?;
135
136 Ok(())
137 })
138 }
139}
140
141impl OpenAiLanguageModelProvider {
142 pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
143 let state = cx.new(|cx| State {
144 api_key: None,
145 api_key_from_env: false,
146 _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
147 cx.notify();
148 }),
149 });
150
151 Self { http_client, state }
152 }
153
154 fn create_language_model(&self, model: open_ai::Model) -> Arc<dyn LanguageModel> {
155 Arc::new(OpenAiLanguageModel {
156 id: LanguageModelId::from(model.id().to_string()),
157 model,
158 state: self.state.clone(),
159 http_client: self.http_client.clone(),
160 request_limiter: RateLimiter::new(4),
161 })
162 }
163}
164
165impl LanguageModelProviderState for OpenAiLanguageModelProvider {
166 type ObservableEntity = State;
167
168 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
169 Some(self.state.clone())
170 }
171}
172
173impl LanguageModelProvider for OpenAiLanguageModelProvider {
174 fn id(&self) -> LanguageModelProviderId {
175 PROVIDER_ID
176 }
177
178 fn name(&self) -> LanguageModelProviderName {
179 PROVIDER_NAME
180 }
181
182 fn icon(&self) -> IconName {
183 IconName::AiOpenAi
184 }
185
186 fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
187 Some(self.create_language_model(open_ai::Model::default()))
188 }
189
190 fn default_fast_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
191 Some(self.create_language_model(open_ai::Model::default_fast()))
192 }
193
194 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
195 let mut models = BTreeMap::default();
196
197 // Add base models from open_ai::Model::iter()
198 for model in open_ai::Model::iter() {
199 if !matches!(model, open_ai::Model::Custom { .. }) {
200 models.insert(model.id().to_string(), model);
201 }
202 }
203
204 // Override with available models from settings
205 for model in &AllLanguageModelSettings::get_global(cx)
206 .openai
207 .available_models
208 {
209 models.insert(
210 model.name.clone(),
211 open_ai::Model::Custom {
212 name: model.name.clone(),
213 display_name: model.display_name.clone(),
214 max_tokens: model.max_tokens,
215 max_output_tokens: model.max_output_tokens,
216 max_completion_tokens: model.max_completion_tokens,
217 reasoning_effort: model.reasoning_effort.clone(),
218 },
219 );
220 }
221
222 models
223 .into_values()
224 .map(|model| self.create_language_model(model))
225 .collect()
226 }
227
228 fn is_authenticated(&self, cx: &App) -> bool {
229 self.state.read(cx).is_authenticated()
230 }
231
232 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
233 self.state.update(cx, |state, cx| state.authenticate(cx))
234 }
235
236 fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
237 cx.new(|cx| ConfigurationView::new(self.state.clone(), window, cx))
238 .into()
239 }
240
241 fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
242 self.state.update(cx, |state, cx| state.reset_api_key(cx))
243 }
244}
245
/// A single OpenAI model exposed through [`OpenAiLanguageModelProvider`].
pub struct OpenAiLanguageModel {
    /// Identifier built from `model.id()`.
    id: LanguageModelId,
    /// The underlying OpenAI model description (built-in or settings-defined `Custom`).
    model: open_ai::Model,
    /// Shared provider state; the API key is read from it on every request.
    state: gpui::Entity<State>,
    http_client: Arc<dyn HttpClient>,
    /// Per-model request limiter (created with `RateLimiter::new(4)`; presumably caps
    /// concurrent in-flight requests — confirm against `RateLimiter`).
    request_limiter: RateLimiter,
}
253
254impl OpenAiLanguageModel {
255 fn stream_completion(
256 &self,
257 request: open_ai::Request,
258 cx: &AsyncApp,
259 ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
260 {
261 let http_client = self.http_client.clone();
262 let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
263 let settings = &AllLanguageModelSettings::get_global(cx).openai;
264 (state.api_key.clone(), settings.api_url.clone())
265 }) else {
266 return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
267 };
268
269 let future = self.request_limiter.stream(async move {
270 let Some(api_key) = api_key else {
271 return Err(LanguageModelCompletionError::NoApiKey {
272 provider: PROVIDER_NAME,
273 });
274 };
275 let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
276 let response = request.await?;
277 Ok(response)
278 });
279
280 async move { Ok(future.await?.boxed()) }.boxed()
281 }
282}
283
284impl LanguageModel for OpenAiLanguageModel {
285 fn id(&self) -> LanguageModelId {
286 self.id.clone()
287 }
288
289 fn name(&self) -> LanguageModelName {
290 LanguageModelName::from(self.model.display_name().to_string())
291 }
292
293 fn provider_id(&self) -> LanguageModelProviderId {
294 PROVIDER_ID
295 }
296
297 fn provider_name(&self) -> LanguageModelProviderName {
298 PROVIDER_NAME
299 }
300
301 fn supports_tools(&self) -> bool {
302 true
303 }
304
305 fn supports_images(&self) -> bool {
306 use open_ai::Model;
307 match &self.model {
308 Model::FourOmni
309 | Model::FourOmniMini
310 | Model::FourPointOne
311 | Model::FourPointOneMini
312 | Model::FourPointOneNano
313 | Model::Five
314 | Model::FiveMini
315 | Model::FiveNano
316 | Model::O1
317 | Model::O3
318 | Model::O4Mini => true,
319 Model::ThreePointFiveTurbo
320 | Model::Four
321 | Model::FourTurbo
322 | Model::O3Mini
323 | Model::Custom { .. } => false,
324 }
325 }
326
327 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
328 match choice {
329 LanguageModelToolChoice::Auto => true,
330 LanguageModelToolChoice::Any => true,
331 LanguageModelToolChoice::None => true,
332 }
333 }
334
335 fn telemetry_id(&self) -> String {
336 format!("openai/{}", self.model.id())
337 }
338
339 fn max_token_count(&self) -> u64 {
340 self.model.max_token_count()
341 }
342
343 fn max_output_tokens(&self) -> Option<u64> {
344 self.model.max_output_tokens()
345 }
346
347 fn count_tokens(
348 &self,
349 request: LanguageModelRequest,
350 cx: &App,
351 ) -> BoxFuture<'static, Result<u64>> {
352 count_open_ai_tokens(request, self.model.clone(), cx)
353 }
354
355 fn stream_completion(
356 &self,
357 request: LanguageModelRequest,
358 cx: &AsyncApp,
359 ) -> BoxFuture<
360 'static,
361 Result<
362 futures::stream::BoxStream<
363 'static,
364 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
365 >,
366 LanguageModelCompletionError,
367 >,
368 > {
369 let request = into_open_ai(
370 request,
371 self.model.id(),
372 self.model.supports_parallel_tool_calls(),
373 self.model.supports_prompt_cache_key(),
374 self.max_output_tokens(),
375 self.model.reasoning_effort(),
376 );
377 let completions = self.stream_completion(request, cx);
378 async move {
379 let mapper = OpenAiEventMapper::new();
380 Ok(mapper.map_stream(completions.await?).boxed())
381 }
382 .boxed()
383 }
384}
385
386pub fn into_open_ai(
387 request: LanguageModelRequest,
388 model_id: &str,
389 supports_parallel_tool_calls: bool,
390 supports_prompt_cache_key: bool,
391 max_output_tokens: Option<u64>,
392 reasoning_effort: Option<ReasoningEffort>,
393) -> open_ai::Request {
394 let stream = !model_id.starts_with("o1-");
395
396 let mut messages = Vec::new();
397 for message in request.messages {
398 for content in message.content {
399 match content {
400 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
401 add_message_content_part(
402 open_ai::MessagePart::Text { text: text },
403 message.role,
404 &mut messages,
405 )
406 }
407 MessageContent::RedactedThinking(_) => {}
408 MessageContent::Image(image) => {
409 add_message_content_part(
410 open_ai::MessagePart::Image {
411 image_url: ImageUrl {
412 url: image.to_base64_url(),
413 detail: None,
414 },
415 },
416 message.role,
417 &mut messages,
418 );
419 }
420 MessageContent::ToolUse(tool_use) => {
421 let tool_call = open_ai::ToolCall {
422 id: tool_use.id.to_string(),
423 content: open_ai::ToolCallContent::Function {
424 function: open_ai::FunctionContent {
425 name: tool_use.name.to_string(),
426 arguments: serde_json::to_string(&tool_use.input)
427 .unwrap_or_default(),
428 },
429 },
430 };
431
432 if let Some(open_ai::RequestMessage::Assistant { tool_calls, .. }) =
433 messages.last_mut()
434 {
435 tool_calls.push(tool_call);
436 } else {
437 messages.push(open_ai::RequestMessage::Assistant {
438 content: None,
439 tool_calls: vec![tool_call],
440 });
441 }
442 }
443 MessageContent::ToolResult(tool_result) => {
444 let content = match &tool_result.content {
445 LanguageModelToolResultContent::Text(text) => {
446 vec![open_ai::MessagePart::Text {
447 text: text.to_string(),
448 }]
449 }
450 LanguageModelToolResultContent::Image(image) => {
451 vec![open_ai::MessagePart::Image {
452 image_url: ImageUrl {
453 url: image.to_base64_url(),
454 detail: None,
455 },
456 }]
457 }
458 };
459
460 messages.push(open_ai::RequestMessage::Tool {
461 content: content.into(),
462 tool_call_id: tool_result.tool_use_id.to_string(),
463 });
464 }
465 }
466 }
467 }
468
469 open_ai::Request {
470 model: model_id.into(),
471 messages,
472 stream,
473 stop: request.stop,
474 temperature: request.temperature.unwrap_or(1.0),
475 max_completion_tokens: max_output_tokens,
476 parallel_tool_calls: if supports_parallel_tool_calls && !request.tools.is_empty() {
477 // Disable parallel tool calls, as the Agent currently expects a maximum of one per turn.
478 Some(false)
479 } else {
480 None
481 },
482 prompt_cache_key: if supports_prompt_cache_key {
483 request.thread_id
484 } else {
485 None
486 },
487 tools: request
488 .tools
489 .into_iter()
490 .map(|tool| open_ai::ToolDefinition::Function {
491 function: open_ai::FunctionDefinition {
492 name: tool.name,
493 description: Some(tool.description),
494 parameters: Some(tool.input_schema),
495 },
496 })
497 .collect(),
498 tool_choice: request.tool_choice.map(|choice| match choice {
499 LanguageModelToolChoice::Auto => open_ai::ToolChoice::Auto,
500 LanguageModelToolChoice::Any => open_ai::ToolChoice::Required,
501 LanguageModelToolChoice::None => open_ai::ToolChoice::None,
502 }),
503 reasoning_effort,
504 }
505}
506
507fn add_message_content_part(
508 new_part: open_ai::MessagePart,
509 role: Role,
510 messages: &mut Vec<open_ai::RequestMessage>,
511) {
512 match (role, messages.last_mut()) {
513 (Role::User, Some(open_ai::RequestMessage::User { content }))
514 | (
515 Role::Assistant,
516 Some(open_ai::RequestMessage::Assistant {
517 content: Some(content),
518 ..
519 }),
520 )
521 | (Role::System, Some(open_ai::RequestMessage::System { content, .. })) => {
522 content.push_part(new_part);
523 }
524 _ => {
525 messages.push(match role {
526 Role::User => open_ai::RequestMessage::User {
527 content: open_ai::MessageContent::from(vec![new_part]),
528 },
529 Role::Assistant => open_ai::RequestMessage::Assistant {
530 content: Some(open_ai::MessageContent::from(vec![new_part])),
531 tool_calls: Vec::new(),
532 },
533 Role::System => open_ai::RequestMessage::System {
534 content: open_ai::MessageContent::from(vec![new_part]),
535 },
536 });
537 }
538 }
539}
540
/// Translates OpenAI streaming chunks into `LanguageModelCompletionEvent`s.
///
/// Tool-call fragments are accumulated across chunks until a `finish_reason` arrives.
pub struct OpenAiEventMapper {
    /// Partially received tool calls, keyed by the `index` carried on each streamed fragment.
    tool_calls_by_index: HashMap<usize, RawToolCall>,
}
544
545impl OpenAiEventMapper {
546 pub fn new() -> Self {
547 Self {
548 tool_calls_by_index: HashMap::default(),
549 }
550 }
551
552 pub fn map_stream(
553 mut self,
554 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
555 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
556 {
557 events.flat_map(move |event| {
558 futures::stream::iter(match event {
559 Ok(event) => self.map_event(event),
560 Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
561 })
562 })
563 }
564
565 pub fn map_event(
566 &mut self,
567 event: ResponseStreamEvent,
568 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
569 let mut events = Vec::new();
570 if let Some(usage) = event.usage {
571 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
572 input_tokens: usage.prompt_tokens,
573 output_tokens: usage.completion_tokens,
574 cache_creation_input_tokens: 0,
575 cache_read_input_tokens: 0,
576 })));
577 }
578
579 let Some(choice) = event.choices.first() else {
580 return events;
581 };
582
583 if let Some(content) = choice.delta.content.clone() {
584 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
585 }
586
587 if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
588 for tool_call in tool_calls {
589 let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
590
591 if let Some(tool_id) = tool_call.id.clone() {
592 entry.id = tool_id;
593 }
594
595 if let Some(function) = tool_call.function.as_ref() {
596 if let Some(name) = function.name.clone() {
597 entry.name = name;
598 }
599
600 if let Some(arguments) = function.arguments.clone() {
601 entry.arguments.push_str(&arguments);
602 }
603 }
604 }
605 }
606
607 match choice.finish_reason.as_deref() {
608 Some("stop") => {
609 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
610 }
611 Some("tool_calls") => {
612 events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
613 match serde_json::Value::from_str(&tool_call.arguments) {
614 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
615 LanguageModelToolUse {
616 id: tool_call.id.clone().into(),
617 name: tool_call.name.as_str().into(),
618 is_input_complete: true,
619 input,
620 raw_input: tool_call.arguments.clone(),
621 },
622 )),
623 Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
624 id: tool_call.id.into(),
625 tool_name: tool_call.name.into(),
626 raw_input: tool_call.arguments.clone().into(),
627 json_parse_error: error.to_string(),
628 }),
629 }
630 }));
631
632 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
633 }
634 Some(stop_reason) => {
635 log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
636 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
637 }
638 None => {}
639 }
640
641 events
642 }
643}
644
/// A tool call being assembled from streamed fragments.
#[derive(Default)]
struct RawToolCall {
    /// Tool-call id, taken from the last fragment that supplied one.
    id: String,
    /// Function name, taken from the last fragment that supplied one.
    name: String,
    /// JSON-encoded arguments, concatenated across fragments.
    arguments: String,
}
651
652pub(crate) fn collect_tiktoken_messages(
653 request: LanguageModelRequest,
654) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
655 request
656 .messages
657 .into_iter()
658 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
659 role: match message.role {
660 Role::User => "user".into(),
661 Role::Assistant => "assistant".into(),
662 Role::System => "system".into(),
663 },
664 content: Some(message.string_contents()),
665 name: None,
666 function_call: None,
667 })
668 .collect::<Vec<_>>()
669}
670
671pub fn count_open_ai_tokens(
672 request: LanguageModelRequest,
673 model: Model,
674 cx: &App,
675) -> BoxFuture<'static, Result<u64>> {
676 cx.background_spawn(async move {
677 let messages = collect_tiktoken_messages(request);
678
679 match model {
680 Model::Custom { max_tokens, .. } => {
681 let model = if max_tokens >= 100_000 {
682 // If the max tokens is 100k or more, it is likely the o200k_base tokenizer from gpt4o
683 "gpt-4o"
684 } else {
685 // Otherwise fallback to gpt-4, since only cl100k_base and o200k_base are
686 // supported with this tiktoken method
687 "gpt-4"
688 };
689 tiktoken_rs::num_tokens_from_messages(model, &messages)
690 }
691 // Currently supported by tiktoken_rs
692 // Sometimes tiktoken-rs is behind on model support. If that is the case, make a new branch
693 // arm with an override. We enumerate all supported models here so that we can check if new
694 // models are supported yet or not.
695 Model::ThreePointFiveTurbo
696 | Model::Four
697 | Model::FourTurbo
698 | Model::FourOmni
699 | Model::FourOmniMini
700 | Model::FourPointOne
701 | Model::FourPointOneMini
702 | Model::FourPointOneNano
703 | Model::O1
704 | Model::O3
705 | Model::O3Mini
706 | Model::O4Mini => tiktoken_rs::num_tokens_from_messages(model.id(), &messages),
707 // GPT-5 models don't have tiktoken support yet; fall back on gpt-4o tokenizer
708 Model::Five | Model::FiveMini | Model::FiveNano => {
709 tiktoken_rs::num_tokens_from_messages("gpt-4o", &messages)
710 }
711 }
712 .map(|tokens| tokens as u64)
713 })
714 .boxed()
715}
716
/// Settings UI for entering, displaying, and resetting the OpenAI API key.
struct ConfigurationView {
    /// Single-line input where the user pastes an API key.
    api_key_editor: Entity<SingleLineInput>,
    /// Shared provider authentication state.
    state: gpui::Entity<State>,
    /// The initial credential-loading task; `Some` while loading, set to `None` when done.
    load_credentials_task: Option<Task<()>>,
}
722
723impl ConfigurationView {
724 fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
725 let api_key_editor = cx.new(|cx| {
726 SingleLineInput::new(
727 window,
728 cx,
729 "sk-000000000000000000000000000000000000000000000000",
730 )
731 });
732
733 cx.observe(&state, |_, _, cx| {
734 cx.notify();
735 })
736 .detach();
737
738 let load_credentials_task = Some(cx.spawn_in(window, {
739 let state = state.clone();
740 async move |this, cx| {
741 if let Some(task) = state
742 .update(cx, |state, cx| state.authenticate(cx))
743 .log_err()
744 {
745 // We don't log an error, because "not signed in" is also an error.
746 let _ = task.await;
747 }
748 this.update(cx, |this, cx| {
749 this.load_credentials_task = None;
750 cx.notify();
751 })
752 .log_err();
753 }
754 }));
755
756 Self {
757 api_key_editor,
758 state,
759 load_credentials_task,
760 }
761 }
762
763 fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
764 let api_key = self
765 .api_key_editor
766 .read(cx)
767 .editor()
768 .read(cx)
769 .text(cx)
770 .trim()
771 .to_string();
772
773 // Don't proceed if no API key is provided and we're not authenticated
774 if api_key.is_empty() && !self.state.read(cx).is_authenticated() {
775 return;
776 }
777
778 let state = self.state.clone();
779 cx.spawn_in(window, async move |_, cx| {
780 state
781 .update(cx, |state, cx| state.set_api_key(api_key, cx))?
782 .await
783 })
784 .detach_and_log_err(cx);
785
786 cx.notify();
787 }
788
789 fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
790 self.api_key_editor.update(cx, |input, cx| {
791 input.editor.update(cx, |editor, cx| {
792 editor.set_text("", window, cx);
793 });
794 });
795
796 let state = self.state.clone();
797 cx.spawn_in(window, async move |_, cx| {
798 state.update(cx, |state, cx| state.reset_api_key(cx))?.await
799 })
800 .detach_and_log_err(cx);
801
802 cx.notify();
803 }
804
805 fn should_render_editor(&self, cx: &mut Context<Self>) -> bool {
806 !self.state.read(cx).is_authenticated()
807 }
808}
809
impl Render for ConfigurationView {
    /// Renders the OpenAI configuration UI.
    ///
    /// While the initial credential load is in flight, only a loading label is shown.
    /// Afterwards the view shows one of two sections:
    /// - the API key entry form, when no key is active;
    /// - a status row with a reset button, when a key is active.
    ///
    /// Either section is followed by a note about OpenAI-compatible providers.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let env_var_set = self.state.read(cx).api_key_from_env;

        let api_key_section = if self.should_render_editor(cx) {
            // No key yet: instructions plus the key editor. `menu::Confirm` actions are
            // handled by `save_api_key`.
            v_flex()
                .on_action(cx.listener(Self::save_api_key))
                .child(Label::new("To use Zed's agent with OpenAI, you need to add an API key. Follow these steps:"))
                .child(
                    List::new()
                        .child(InstructionListItem::new(
                            "Create one by visiting",
                            Some("OpenAI's console"),
                            Some("https://platform.openai.com/api-keys"),
                        ))
                        .child(InstructionListItem::text_only(
                            "Ensure your OpenAI account has credits",
                        ))
                        .child(InstructionListItem::text_only(
                            "Paste your API key below and hit enter to start using the assistant",
                        )),
                )
                .child(self.api_key_editor.clone())
                .child(
                    Label::new(
                        format!("You can also assign the {OPENAI_API_KEY_VAR} environment variable and restart Zed."),
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .child(
                    Label::new(
                        "Note that having a subscription for another service like GitHub Copilot won't work.",
                    )
                    .size(LabelSize::Small).color(Color::Muted),
                )
                .into_any()
        } else {
            // Key active: report where it came from and offer a reset.
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new(if env_var_set {
                            format!("API key set in {OPENAI_API_KEY_VAR} environment variable.")
                        } else {
                            "API key configured.".to_string()
                        })),
                )
                .child(
                    Button::new("reset-api-key", "Reset API Key")
                        .label_size(LabelSize::Small)
                        .icon(IconName::Undo)
                        .icon_size(IconSize::Small)
                        .icon_position(IconPosition::Start)
                        .layer(ElevationIndex::ModalSurface)
                        // `State::authenticate` prefers the environment variable, so a
                        // reset alone would not remove an env-provided key; explain why.
                        .when(env_var_set, |this| {
                            this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OPENAI_API_KEY_VAR} environment variable.")))
                        })
                        .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))),
                )
                .into_any()
        };

        // Pointer to docs for OpenAI-compatible providers; gets a top border only when the
        // key editor is shown above it.
        let compatible_api_section = h_flex()
            .mt_1p5()
            .gap_0p5()
            .flex_wrap()
            .when(self.should_render_editor(cx), |this| {
                this.pt_1p5()
                    .border_t_1()
                    .border_color(cx.theme().colors().border_variant)
            })
            .child(
                h_flex()
                    .gap_2()
                    .child(
                        Icon::new(IconName::Info)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(Label::new("Zed also supports OpenAI-compatible models.")),
            )
            .child(
                Button::new("docs", "Learn More")
                    .icon(IconName::ArrowUpRight)
                    .icon_size(IconSize::Small)
                    .icon_color(Color::Muted)
                    .on_click(move |_, _window, cx| {
                        cx.open_url("https://zed.dev/docs/ai/llm-providers#openai-api-compatible")
                    }),
            );

        if self.load_credentials_task.is_some() {
            div().child(Label::new("Loading credentials…")).into_any()
        } else {
            v_flex()
                .size_full()
                .child(api_key_section)
                .child(compatible_api_section)
                .into_any()
        }
    }
}
920
#[cfg(test)]
mod tests {
    use gpui::TestAppContext;
    use language_model::LanguageModelRequestMessage;

    use super::*;

    /// Token counting must succeed with a non-zero count for every `Model` variant, which
    /// catches models that tiktoken-rs does not support yet.
    #[gpui::test]
    fn tiktoken_rs_support(cx: &TestAppContext) {
        let single_user_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![MessageContent::Text("message".into())],
            cache: false,
        };
        let request = LanguageModelRequest {
            thread_id: None,
            prompt_id: None,
            intent: None,
            mode: None,
            messages: vec![single_user_message],
            tools: Vec::new(),
            tool_choice: None,
            stop: Vec::new(),
            temperature: None,
            thinking_allowed: true,
        };

        for model in Model::iter() {
            let count_future = count_open_ai_tokens(request.clone(), model, &cx.app.borrow());
            let token_count = cx.executor().block(count_future).unwrap();
            assert!(token_count > 0);
        }
    }
}