1use std::pin::Pin;
2use std::str::FromStr as _;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use cloud_llm_client::CompletionIntent;
7use collections::HashMap;
8use copilot::copilot_chat::{
9 ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
10 Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
11 ToolCall,
12};
13use copilot::{Copilot, Status};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, Stream, StreamExt};
17use gpui::{Action, AnyView, App, AsyncApp, Entity, Render, Subscription, Task, svg};
18use language::language_settings::all_language_settings;
19use language_model::{
20 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
22 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
23 LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
24 LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
25 StopReason, TokenUsage,
26};
27use settings::SettingsStore;
28use ui::{CommonAnimationExt, prelude::*};
29use util::debug_panic;
30
/// Provider id reported by both the provider (`id()`) and its models (`provider_id()`).
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
/// Display name reported by both the provider (`name()`) and its models (`provider_name()`).
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("GitHub Copilot Chat");
34
/// Language model provider backed by GitHub Copilot Chat.
///
/// The available models are read from the global `CopilotChat` entity; `state` is the
/// entity exposed to observers through `LanguageModelProviderState::observable_entity`.
pub struct CopilotChatLanguageModelProvider {
    state: Entity<State>,
}
38
39pub struct State {
40 _copilot_chat_subscription: Option<Subscription>,
41 _settings_subscription: Subscription,
42}
43
44impl State {
45 fn is_authenticated(&self, cx: &App) -> bool {
46 CopilotChat::global(cx)
47 .map(|m| m.read(cx).is_authenticated())
48 .unwrap_or(false)
49 }
50}
51
52impl CopilotChatLanguageModelProvider {
53 pub fn new(cx: &mut App) -> Self {
54 let state = cx.new(|cx| {
55 let copilot_chat_subscription = CopilotChat::global(cx)
56 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
57 State {
58 _copilot_chat_subscription: copilot_chat_subscription,
59 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
60 if let Some(copilot_chat) = CopilotChat::global(cx) {
61 let language_settings = all_language_settings(None, cx);
62 let configuration = copilot::copilot_chat::CopilotChatConfiguration {
63 enterprise_uri: language_settings
64 .edit_predictions
65 .copilot
66 .enterprise_uri
67 .clone(),
68 };
69 copilot_chat.update(cx, |chat, cx| {
70 chat.set_configuration(configuration, cx);
71 });
72 }
73 cx.notify();
74 }),
75 }
76 });
77
78 Self { state }
79 }
80
81 fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
82 Arc::new(CopilotChatLanguageModel {
83 model,
84 request_limiter: RateLimiter::new(4),
85 })
86 }
87}
88
89impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
90 type ObservableEntity = State;
91
92 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
93 Some(self.state.clone())
94 }
95}
96
97impl LanguageModelProvider for CopilotChatLanguageModelProvider {
98 fn id(&self) -> LanguageModelProviderId {
99 PROVIDER_ID
100 }
101
102 fn name(&self) -> LanguageModelProviderName {
103 PROVIDER_NAME
104 }
105
106 fn icon(&self) -> IconName {
107 IconName::Copilot
108 }
109
110 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
111 let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
112 models
113 .first()
114 .map(|model| self.create_language_model(model.clone()))
115 }
116
117 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
118 // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
119 // model (e.g. 4o) and a sensible choice when considering premium requests
120 self.default_model(cx)
121 }
122
123 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
124 let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
125 return Vec::new();
126 };
127 models
128 .iter()
129 .map(|model| self.create_language_model(model.clone()))
130 .collect()
131 }
132
133 fn is_authenticated(&self, cx: &App) -> bool {
134 self.state.read(cx).is_authenticated(cx)
135 }
136
137 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
138 if self.is_authenticated(cx) {
139 return Task::ready(Ok(()));
140 };
141
142 let Some(copilot) = Copilot::global(cx) else {
143 return Task::ready( Err(anyhow!(
144 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
145 ).into()));
146 };
147
148 let err = match copilot.read(cx).status() {
149 Status::Authorized => return Task::ready(Ok(())),
150 Status::Disabled => anyhow!(
151 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
152 ),
153 Status::Error(err) => anyhow!(format!(
154 "Received the following error while signing into Copilot: {err}"
155 )),
156 Status::Starting { task: _ } => anyhow!(
157 "Copilot is still starting, please wait for Copilot to start then try again"
158 ),
159 Status::Unauthorized => anyhow!(
160 "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
161 ),
162 Status::SignedOut { .. } => {
163 anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
164 }
165 Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
166 };
167
168 Task::ready(Err(err.into()))
169 }
170
171 fn configuration_view(
172 &self,
173 _target_agent: language_model::ConfigurationViewTargetAgent,
174 _: &mut Window,
175 cx: &mut App,
176 ) -> AnyView {
177 let state = self.state.clone();
178 cx.new(|cx| ConfigurationView::new(state, cx)).into()
179 }
180
181 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
182 Task::ready(Err(anyhow!(
183 "Signing out of GitHub Copilot Chat is currently not supported."
184 )))
185 }
186}
187
188fn collect_tiktoken_messages(
189 request: LanguageModelRequest,
190) -> Vec<tiktoken_rs::ChatCompletionRequestMessage> {
191 request
192 .messages
193 .into_iter()
194 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
195 role: match message.role {
196 Role::User => "user".into(),
197 Role::Assistant => "assistant".into(),
198 Role::System => "system".into(),
199 },
200 content: Some(message.string_contents()),
201 name: None,
202 function_call: None,
203 })
204 .collect::<Vec<_>>()
205}
206
207pub struct CopilotChatLanguageModel {
208 model: CopilotChatModel,
209 request_limiter: RateLimiter,
210}
211
212impl LanguageModel for CopilotChatLanguageModel {
213 fn id(&self) -> LanguageModelId {
214 LanguageModelId::from(self.model.id().to_string())
215 }
216
217 fn name(&self) -> LanguageModelName {
218 LanguageModelName::from(self.model.display_name().to_string())
219 }
220
221 fn provider_id(&self) -> LanguageModelProviderId {
222 PROVIDER_ID
223 }
224
225 fn provider_name(&self) -> LanguageModelProviderName {
226 PROVIDER_NAME
227 }
228
229 fn supports_tools(&self) -> bool {
230 self.model.supports_tools()
231 }
232
233 fn supports_images(&self) -> bool {
234 self.model.supports_vision()
235 }
236
237 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
238 match self.model.vendor() {
239 ModelVendor::OpenAI | ModelVendor::Anthropic => {
240 LanguageModelToolSchemaFormat::JsonSchema
241 }
242 ModelVendor::Google | ModelVendor::XAI | ModelVendor::Unknown => {
243 LanguageModelToolSchemaFormat::JsonSchemaSubset
244 }
245 }
246 }
247
248 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
249 match choice {
250 LanguageModelToolChoice::Auto
251 | LanguageModelToolChoice::Any
252 | LanguageModelToolChoice::None => self.supports_tools(),
253 }
254 }
255
256 fn telemetry_id(&self) -> String {
257 format!("copilot_chat/{}", self.model.id())
258 }
259
260 fn max_token_count(&self) -> u64 {
261 self.model.max_token_count()
262 }
263
264 fn count_tokens(
265 &self,
266 request: LanguageModelRequest,
267 cx: &App,
268 ) -> BoxFuture<'static, Result<u64>> {
269 let model = self.model.clone();
270 cx.background_spawn(async move {
271 let messages = collect_tiktoken_messages(request);
272 // Copilot uses OpenAI tiktoken tokenizer for all it's model irrespective of the underlying provider(vendor).
273 let tokenizer_model = match model.tokenizer() {
274 Some("o200k_base") => "gpt-4o",
275 Some("cl100k_base") => "gpt-4",
276 _ => "gpt-4o",
277 };
278
279 tiktoken_rs::num_tokens_from_messages(tokenizer_model, &messages)
280 .map(|tokens| tokens as u64)
281 })
282 .boxed()
283 }
284
285 fn stream_completion(
286 &self,
287 request: LanguageModelRequest,
288 cx: &AsyncApp,
289 ) -> BoxFuture<
290 'static,
291 Result<
292 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
293 LanguageModelCompletionError,
294 >,
295 > {
296 let is_user_initiated = request.intent.is_none_or(|intent| match intent {
297 CompletionIntent::UserPrompt
298 | CompletionIntent::ThreadContextSummarization
299 | CompletionIntent::InlineAssist
300 | CompletionIntent::TerminalInlineAssist
301 | CompletionIntent::GenerateGitCommitMessage => true,
302
303 CompletionIntent::ToolResults
304 | CompletionIntent::ThreadSummarization
305 | CompletionIntent::CreateFile
306 | CompletionIntent::EditFile => false,
307 });
308
309 let copilot_request = match into_copilot_chat(&self.model, request) {
310 Ok(request) => request,
311 Err(err) => return futures::future::ready(Err(err.into())).boxed(),
312 };
313 let is_streaming = copilot_request.stream;
314
315 let request_limiter = self.request_limiter.clone();
316 let future = cx.spawn(async move |cx| {
317 let request =
318 CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
319 request_limiter
320 .stream(async move {
321 let response = request.await?;
322 Ok(map_to_language_model_completion_events(
323 response,
324 is_streaming,
325 ))
326 })
327 .await
328 });
329 async move { Ok(future.await?.boxed()) }.boxed()
330 }
331}
332
333pub fn map_to_language_model_completion_events(
334 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
335 is_streaming: bool,
336) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
337 #[derive(Default)]
338 struct RawToolCall {
339 id: String,
340 name: String,
341 arguments: String,
342 }
343
344 struct State {
345 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
346 tool_calls_by_index: HashMap<usize, RawToolCall>,
347 }
348
349 futures::stream::unfold(
350 State {
351 events,
352 tool_calls_by_index: HashMap::default(),
353 },
354 move |mut state| async move {
355 if let Some(event) = state.events.next().await {
356 match event {
357 Ok(event) => {
358 let Some(choice) = event.choices.first() else {
359 return Some((
360 vec![Err(anyhow!("Response contained no choices").into())],
361 state,
362 ));
363 };
364
365 let delta = if is_streaming {
366 choice.delta.as_ref()
367 } else {
368 choice.message.as_ref()
369 };
370
371 let Some(delta) = delta else {
372 return Some((
373 vec![Err(anyhow!("Response contained no delta").into())],
374 state,
375 ));
376 };
377
378 let mut events = Vec::new();
379 if let Some(content) = delta.content.clone() {
380 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
381 }
382
383 for tool_call in &delta.tool_calls {
384 let entry = state
385 .tool_calls_by_index
386 .entry(tool_call.index)
387 .or_default();
388
389 if let Some(tool_id) = tool_call.id.clone() {
390 entry.id = tool_id;
391 }
392
393 if let Some(function) = tool_call.function.as_ref() {
394 if let Some(name) = function.name.clone() {
395 entry.name = name;
396 }
397
398 if let Some(arguments) = function.arguments.clone() {
399 entry.arguments.push_str(&arguments);
400 }
401 }
402 }
403
404 if let Some(usage) = event.usage {
405 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
406 TokenUsage {
407 input_tokens: usage.prompt_tokens,
408 output_tokens: usage.completion_tokens,
409 cache_creation_input_tokens: 0,
410 cache_read_input_tokens: 0,
411 },
412 )));
413 }
414
415 match choice.finish_reason.as_deref() {
416 Some("stop") => {
417 events.push(Ok(LanguageModelCompletionEvent::Stop(
418 StopReason::EndTurn,
419 )));
420 }
421 Some("tool_calls") => {
422 events.extend(state.tool_calls_by_index.drain().map(
423 |(_, tool_call)| {
424 // The model can output an empty string
425 // to indicate the absence of arguments.
426 // When that happens, create an empty
427 // object instead.
428 let arguments = if tool_call.arguments.is_empty() {
429 Ok(serde_json::Value::Object(Default::default()))
430 } else {
431 serde_json::Value::from_str(&tool_call.arguments)
432 };
433 match arguments {
434 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
435 LanguageModelToolUse {
436 id: tool_call.id.clone().into(),
437 name: tool_call.name.as_str().into(),
438 is_input_complete: true,
439 input,
440 raw_input: tool_call.arguments.clone(),
441 },
442 )),
443 Err(error) => Ok(
444 LanguageModelCompletionEvent::ToolUseJsonParseError {
445 id: tool_call.id.into(),
446 tool_name: tool_call.name.as_str().into(),
447 raw_input: tool_call.arguments.into(),
448 json_parse_error: error.to_string(),
449 },
450 ),
451 }
452 },
453 ));
454
455 events.push(Ok(LanguageModelCompletionEvent::Stop(
456 StopReason::ToolUse,
457 )));
458 }
459 Some(stop_reason) => {
460 log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
461 events.push(Ok(LanguageModelCompletionEvent::Stop(
462 StopReason::EndTurn,
463 )));
464 }
465 None => {}
466 }
467
468 return Some((events, state));
469 }
470 Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
471 }
472 }
473
474 None
475 },
476 )
477 .flat_map(futures::stream::iter)
478}
479
480fn into_copilot_chat(
481 model: &copilot::copilot_chat::Model,
482 request: LanguageModelRequest,
483) -> Result<CopilotChatRequest> {
484 let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
485 for message in request.messages {
486 if let Some(last_message) = request_messages.last_mut() {
487 if last_message.role == message.role {
488 last_message.content.extend(message.content);
489 } else {
490 request_messages.push(message);
491 }
492 } else {
493 request_messages.push(message);
494 }
495 }
496
497 let mut messages: Vec<ChatMessage> = Vec::new();
498 for message in request_messages {
499 match message.role {
500 Role::User => {
501 for content in &message.content {
502 if let MessageContent::ToolResult(tool_result) = content {
503 let content = match &tool_result.content {
504 LanguageModelToolResultContent::Text(text) => text.to_string().into(),
505 LanguageModelToolResultContent::Image(image) => {
506 if model.supports_vision() {
507 ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
508 image_url: ImageUrl {
509 url: image.to_base64_url(),
510 },
511 }])
512 } else {
513 debug_panic!(
514 "This should be caught at {} level",
515 tool_result.tool_name
516 );
517 "[Tool responded with an image, but this model does not support vision]".to_string().into()
518 }
519 }
520 };
521
522 messages.push(ChatMessage::Tool {
523 tool_call_id: tool_result.tool_use_id.to_string(),
524 content,
525 });
526 }
527 }
528
529 let mut content_parts = Vec::new();
530 for content in &message.content {
531 match content {
532 MessageContent::Text(text) | MessageContent::Thinking { text, .. }
533 if !text.is_empty() =>
534 {
535 if let Some(ChatMessagePart::Text { text: text_content }) =
536 content_parts.last_mut()
537 {
538 text_content.push_str(text);
539 } else {
540 content_parts.push(ChatMessagePart::Text {
541 text: text.to_string(),
542 });
543 }
544 }
545 MessageContent::Image(image) if model.supports_vision() => {
546 content_parts.push(ChatMessagePart::Image {
547 image_url: ImageUrl {
548 url: image.to_base64_url(),
549 },
550 });
551 }
552 _ => {}
553 }
554 }
555
556 if !content_parts.is_empty() {
557 messages.push(ChatMessage::User {
558 content: content_parts.into(),
559 });
560 }
561 }
562 Role::Assistant => {
563 let mut tool_calls = Vec::new();
564 for content in &message.content {
565 if let MessageContent::ToolUse(tool_use) = content {
566 tool_calls.push(ToolCall {
567 id: tool_use.id.to_string(),
568 content: copilot::copilot_chat::ToolCallContent::Function {
569 function: copilot::copilot_chat::FunctionContent {
570 name: tool_use.name.to_string(),
571 arguments: serde_json::to_string(&tool_use.input)?,
572 },
573 },
574 });
575 }
576 }
577
578 let text_content = {
579 let mut buffer = String::new();
580 for string in message.content.iter().filter_map(|content| match content {
581 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
582 Some(text.as_str())
583 }
584 MessageContent::ToolUse(_)
585 | MessageContent::RedactedThinking(_)
586 | MessageContent::ToolResult(_)
587 | MessageContent::Image(_) => None,
588 }) {
589 buffer.push_str(string);
590 }
591
592 buffer
593 };
594
595 messages.push(ChatMessage::Assistant {
596 content: if text_content.is_empty() {
597 ChatMessageContent::empty()
598 } else {
599 text_content.into()
600 },
601 tool_calls,
602 });
603 }
604 Role::System => messages.push(ChatMessage::System {
605 content: message.string_contents(),
606 }),
607 }
608 }
609
610 let tools = request
611 .tools
612 .iter()
613 .map(|tool| Tool::Function {
614 function: copilot::copilot_chat::Function {
615 name: tool.name.clone(),
616 description: tool.description.clone(),
617 parameters: tool.input_schema.clone(),
618 },
619 })
620 .collect::<Vec<_>>();
621
622 Ok(CopilotChatRequest {
623 intent: true,
624 n: 1,
625 stream: model.uses_streaming(),
626 temperature: 0.1,
627 model: model.id().to_string(),
628 messages,
629 tools,
630 tool_choice: request.tool_choice.map(|choice| match choice {
631 LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
632 LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
633 LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
634 }),
635 })
636}
637
638struct ConfigurationView {
639 copilot_status: Option<copilot::Status>,
640 state: Entity<State>,
641 _subscription: Option<Subscription>,
642}
643
644impl ConfigurationView {
645 pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
646 let copilot = Copilot::global(cx);
647
648 Self {
649 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
650 state,
651 _subscription: copilot.as_ref().map(|copilot| {
652 cx.observe(copilot, |this, model, cx| {
653 this.copilot_status = Some(model.read(cx).status());
654 cx.notify();
655 })
656 }),
657 }
658 }
659}
660
661impl Render for ConfigurationView {
662 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
663 if self.state.read(cx).is_authenticated(cx) {
664 h_flex()
665 .mt_1()
666 .p_1()
667 .justify_between()
668 .rounded_md()
669 .border_1()
670 .border_color(cx.theme().colors().border)
671 .bg(cx.theme().colors().background)
672 .child(
673 h_flex()
674 .gap_1()
675 .child(Icon::new(IconName::Check).color(Color::Success))
676 .child(Label::new("Authorized")),
677 )
678 .child(
679 Button::new("sign_out", "Sign Out")
680 .label_size(LabelSize::Small)
681 .on_click(|_, window, cx| {
682 window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
683 }),
684 )
685 } else {
686 let loading_icon = Icon::new(IconName::ArrowCircle).with_rotate_animation(4);
687
688 const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
689
690 match &self.copilot_status {
691 Some(status) => match status {
692 Status::Starting { task: _ } => h_flex()
693 .gap_2()
694 .child(loading_icon)
695 .child(Label::new("Starting Copilot…")),
696 Status::SigningIn { prompt: _ }
697 | Status::SignedOut {
698 awaiting_signing_in: true,
699 } => h_flex()
700 .gap_2()
701 .child(loading_icon)
702 .child(Label::new("Signing into Copilot…")),
703 Status::Error(_) => {
704 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
705 v_flex()
706 .gap_6()
707 .child(Label::new(LABEL))
708 .child(svg().size_8().path(IconName::CopilotError.path()))
709 }
710 _ => {
711 const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
712
713 v_flex().gap_2().child(Label::new(LABEL)).child(
714 Button::new("sign_in", "Sign in to use GitHub Copilot")
715 .icon_color(Color::Muted)
716 .icon(IconName::Github)
717 .icon_position(IconPosition::Start)
718 .icon_size(IconSize::Medium)
719 .full_width()
720 .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
721 )
722 }
723 },
724 None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
725 }
726 }
727 }
728}