1use std::pin::Pin;
2use std::str::FromStr as _;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use cloud_llm_client::CompletionIntent;
7use collections::HashMap;
8use copilot::copilot_chat::{
9 ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
10 Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
11 ToolCall,
12};
13use copilot::{Copilot, Status};
14use futures::future::BoxFuture;
15use futures::stream::BoxStream;
16use futures::{FutureExt, Stream, StreamExt};
17use gpui::{
18 Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
19 Transformation, percentage, svg,
20};
21use language::language_settings::all_language_settings;
22use language_model::{
23 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
24 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
25 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
26 LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
27 LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
28 StopReason, TokenUsage,
29};
30use settings::SettingsStore;
31use std::time::Duration;
32use ui::prelude::*;
33use util::debug_panic;
34
35use super::anthropic::count_anthropic_tokens;
36use super::google::count_google_tokens;
37use super::open_ai::count_open_ai_tokens;
38
39const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
40const PROVIDER_NAME: LanguageModelProviderName =
41 LanguageModelProviderName::new("GitHub Copilot Chat");
42
43pub struct CopilotChatLanguageModelProvider {
44 state: Entity<State>,
45}
46
47pub struct State {
48 _copilot_chat_subscription: Option<Subscription>,
49 _settings_subscription: Subscription,
50}
51
52impl State {
53 fn is_authenticated(&self, cx: &App) -> bool {
54 CopilotChat::global(cx)
55 .map(|m| m.read(cx).is_authenticated())
56 .unwrap_or(false)
57 }
58}
59
60impl CopilotChatLanguageModelProvider {
61 pub fn new(cx: &mut App) -> Self {
62 let state = cx.new(|cx| {
63 let copilot_chat_subscription = CopilotChat::global(cx)
64 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
65 State {
66 _copilot_chat_subscription: copilot_chat_subscription,
67 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
68 if let Some(copilot_chat) = CopilotChat::global(cx) {
69 let language_settings = all_language_settings(None, cx);
70 let configuration = copilot::copilot_chat::CopilotChatConfiguration {
71 enterprise_uri: language_settings
72 .edit_predictions
73 .copilot
74 .enterprise_uri
75 .clone(),
76 };
77 copilot_chat.update(cx, |chat, cx| {
78 chat.set_configuration(configuration, cx);
79 });
80 }
81 cx.notify();
82 }),
83 }
84 });
85
86 Self { state }
87 }
88
89 fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
90 Arc::new(CopilotChatLanguageModel {
91 model,
92 request_limiter: RateLimiter::new(4),
93 })
94 }
95}
96
97impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
98 type ObservableEntity = State;
99
100 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
101 Some(self.state.clone())
102 }
103}
104
105impl LanguageModelProvider for CopilotChatLanguageModelProvider {
106 fn id(&self) -> LanguageModelProviderId {
107 PROVIDER_ID
108 }
109
110 fn name(&self) -> LanguageModelProviderName {
111 PROVIDER_NAME
112 }
113
114 fn icon(&self) -> IconName {
115 IconName::Copilot
116 }
117
118 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
119 let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
120 models
121 .first()
122 .map(|model| self.create_language_model(model.clone()))
123 }
124
125 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
126 // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
127 // model (e.g. 4o) and a sensible choice when considering premium requests
128 self.default_model(cx)
129 }
130
131 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
132 let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
133 return Vec::new();
134 };
135 models
136 .iter()
137 .map(|model| self.create_language_model(model.clone()))
138 .collect()
139 }
140
141 fn is_authenticated(&self, cx: &App) -> bool {
142 self.state.read(cx).is_authenticated(cx)
143 }
144
145 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
146 if self.is_authenticated(cx) {
147 return Task::ready(Ok(()));
148 };
149
150 let Some(copilot) = Copilot::global(cx) else {
151 return Task::ready( Err(anyhow!(
152 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
153 ).into()));
154 };
155
156 let err = match copilot.read(cx).status() {
157 Status::Authorized => return Task::ready(Ok(())),
158 Status::Disabled => anyhow!(
159 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
160 ),
161 Status::Error(err) => anyhow!(format!(
162 "Received the following error while signing into Copilot: {err}"
163 )),
164 Status::Starting { task: _ } => anyhow!(
165 "Copilot is still starting, please wait for Copilot to start then try again"
166 ),
167 Status::Unauthorized => anyhow!(
168 "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
169 ),
170 Status::SignedOut { .. } => {
171 anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
172 }
173 Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
174 };
175
176 Task::ready(Err(err.into()))
177 }
178
179 fn configuration_view(
180 &self,
181 _target_agent: language_model::ConfigurationViewTargetAgent,
182 _: &mut Window,
183 cx: &mut App,
184 ) -> AnyView {
185 let state = self.state.clone();
186 cx.new(|cx| ConfigurationView::new(state, cx)).into()
187 }
188
189 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
190 Task::ready(Err(anyhow!(
191 "Signing out of GitHub Copilot Chat is currently not supported."
192 )))
193 }
194}
195
196pub struct CopilotChatLanguageModel {
197 model: CopilotChatModel,
198 request_limiter: RateLimiter,
199}
200
201impl LanguageModel for CopilotChatLanguageModel {
202 fn id(&self) -> LanguageModelId {
203 LanguageModelId::from(self.model.id().to_string())
204 }
205
206 fn name(&self) -> LanguageModelName {
207 LanguageModelName::from(self.model.display_name().to_string())
208 }
209
210 fn provider_id(&self) -> LanguageModelProviderId {
211 PROVIDER_ID
212 }
213
214 fn provider_name(&self) -> LanguageModelProviderName {
215 PROVIDER_NAME
216 }
217
218 fn supports_tools(&self) -> bool {
219 self.model.supports_tools()
220 }
221
222 fn supports_images(&self) -> bool {
223 self.model.supports_vision()
224 }
225
226 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
227 match self.model.vendor() {
228 ModelVendor::OpenAI | ModelVendor::Anthropic => {
229 LanguageModelToolSchemaFormat::JsonSchema
230 }
231 ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
232 }
233 }
234
235 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
236 match choice {
237 LanguageModelToolChoice::Auto
238 | LanguageModelToolChoice::Any
239 | LanguageModelToolChoice::None => self.supports_tools(),
240 }
241 }
242
243 fn telemetry_id(&self) -> String {
244 format!("copilot_chat/{}", self.model.id())
245 }
246
247 fn max_token_count(&self) -> u64 {
248 self.model.max_token_count()
249 }
250
251 fn count_tokens(
252 &self,
253 request: LanguageModelRequest,
254 cx: &App,
255 ) -> BoxFuture<'static, Result<u64>> {
256 match self.model.vendor() {
257 ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
258 ModelVendor::Google => count_google_tokens(request, cx),
259 ModelVendor::OpenAI => {
260 let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
261 count_open_ai_tokens(request, model, cx)
262 }
263 }
264 }
265
266 fn stream_completion(
267 &self,
268 request: LanguageModelRequest,
269 cx: &AsyncApp,
270 ) -> BoxFuture<
271 'static,
272 Result<
273 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
274 LanguageModelCompletionError,
275 >,
276 > {
277 let is_user_initiated = request.intent.is_none_or(|intent| match intent {
278 CompletionIntent::UserPrompt
279 | CompletionIntent::ThreadContextSummarization
280 | CompletionIntent::InlineAssist
281 | CompletionIntent::TerminalInlineAssist
282 | CompletionIntent::GenerateGitCommitMessage => true,
283
284 CompletionIntent::ToolResults
285 | CompletionIntent::ThreadSummarization
286 | CompletionIntent::CreateFile
287 | CompletionIntent::EditFile => false,
288 });
289
290 let copilot_request = match into_copilot_chat(&self.model, request) {
291 Ok(request) => request,
292 Err(err) => return futures::future::ready(Err(err.into())).boxed(),
293 };
294 let is_streaming = copilot_request.stream;
295
296 let request_limiter = self.request_limiter.clone();
297 let future = cx.spawn(async move |cx| {
298 let request =
299 CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
300 request_limiter
301 .stream(async move {
302 let response = request.await?;
303 Ok(map_to_language_model_completion_events(
304 response,
305 is_streaming,
306 ))
307 })
308 .await
309 });
310 async move { Ok(future.await?.boxed()) }.boxed()
311 }
312}
313
314pub fn map_to_language_model_completion_events(
315 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
316 is_streaming: bool,
317) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
318 #[derive(Default)]
319 struct RawToolCall {
320 id: String,
321 name: String,
322 arguments: String,
323 }
324
325 struct State {
326 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
327 tool_calls_by_index: HashMap<usize, RawToolCall>,
328 }
329
330 futures::stream::unfold(
331 State {
332 events,
333 tool_calls_by_index: HashMap::default(),
334 },
335 move |mut state| async move {
336 if let Some(event) = state.events.next().await {
337 match event {
338 Ok(event) => {
339 let Some(choice) = event.choices.first() else {
340 return Some((
341 vec![Err(anyhow!("Response contained no choices").into())],
342 state,
343 ));
344 };
345
346 let delta = if is_streaming {
347 choice.delta.as_ref()
348 } else {
349 choice.message.as_ref()
350 };
351
352 let Some(delta) = delta else {
353 return Some((
354 vec![Err(anyhow!("Response contained no delta").into())],
355 state,
356 ));
357 };
358
359 let mut events = Vec::new();
360 if let Some(content) = delta.content.clone() {
361 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
362 }
363
364 for tool_call in &delta.tool_calls {
365 let entry = state
366 .tool_calls_by_index
367 .entry(tool_call.index)
368 .or_default();
369
370 if let Some(tool_id) = tool_call.id.clone() {
371 entry.id = tool_id;
372 }
373
374 if let Some(function) = tool_call.function.as_ref() {
375 if let Some(name) = function.name.clone() {
376 entry.name = name;
377 }
378
379 if let Some(arguments) = function.arguments.clone() {
380 entry.arguments.push_str(&arguments);
381 }
382 }
383 }
384
385 if let Some(usage) = event.usage {
386 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
387 TokenUsage {
388 input_tokens: usage.prompt_tokens,
389 output_tokens: usage.completion_tokens,
390 cache_creation_input_tokens: 0,
391 cache_read_input_tokens: 0,
392 },
393 )));
394 }
395
396 match choice.finish_reason.as_deref() {
397 Some("stop") => {
398 events.push(Ok(LanguageModelCompletionEvent::Stop(
399 StopReason::EndTurn,
400 )));
401 }
402 Some("tool_calls") => {
403 events.extend(state.tool_calls_by_index.drain().map(
404 |(_, tool_call)| {
405 // The model can output an empty string
406 // to indicate the absence of arguments.
407 // When that happens, create an empty
408 // object instead.
409 let arguments = if tool_call.arguments.is_empty() {
410 Ok(serde_json::Value::Object(Default::default()))
411 } else {
412 serde_json::Value::from_str(&tool_call.arguments)
413 };
414 match arguments {
415 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
416 LanguageModelToolUse {
417 id: tool_call.id.clone().into(),
418 name: tool_call.name.as_str().into(),
419 is_input_complete: true,
420 input,
421 raw_input: tool_call.arguments.clone(),
422 },
423 )),
424 Err(error) => Ok(
425 LanguageModelCompletionEvent::ToolUseJsonParseError {
426 id: tool_call.id.into(),
427 tool_name: tool_call.name.as_str().into(),
428 raw_input: tool_call.arguments.into(),
429 json_parse_error: error.to_string(),
430 },
431 ),
432 }
433 },
434 ));
435
436 events.push(Ok(LanguageModelCompletionEvent::Stop(
437 StopReason::ToolUse,
438 )));
439 }
440 Some(stop_reason) => {
441 log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
442 events.push(Ok(LanguageModelCompletionEvent::Stop(
443 StopReason::EndTurn,
444 )));
445 }
446 None => {}
447 }
448
449 return Some((events, state));
450 }
451 Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
452 }
453 }
454
455 None
456 },
457 )
458 .flat_map(futures::stream::iter)
459}
460
461fn into_copilot_chat(
462 model: &copilot::copilot_chat::Model,
463 request: LanguageModelRequest,
464) -> Result<CopilotChatRequest> {
465 let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
466 for message in request.messages {
467 if let Some(last_message) = request_messages.last_mut() {
468 if last_message.role == message.role {
469 last_message.content.extend(message.content);
470 } else {
471 request_messages.push(message);
472 }
473 } else {
474 request_messages.push(message);
475 }
476 }
477
478 let mut messages: Vec<ChatMessage> = Vec::new();
479 for message in request_messages {
480 match message.role {
481 Role::User => {
482 for content in &message.content {
483 if let MessageContent::ToolResult(tool_result) = content {
484 let content = match &tool_result.content {
485 LanguageModelToolResultContent::Text(text) => text.to_string().into(),
486 LanguageModelToolResultContent::Image(image) => {
487 if model.supports_vision() {
488 ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
489 image_url: ImageUrl {
490 url: image.to_base64_url(),
491 },
492 }])
493 } else {
494 debug_panic!(
495 "This should be caught at {} level",
496 tool_result.tool_name
497 );
498 "[Tool responded with an image, but this model does not support vision]".to_string().into()
499 }
500 }
501 };
502
503 messages.push(ChatMessage::Tool {
504 tool_call_id: tool_result.tool_use_id.to_string(),
505 content,
506 });
507 }
508 }
509
510 let mut content_parts = Vec::new();
511 for content in &message.content {
512 match content {
513 MessageContent::Text(text) | MessageContent::Thinking { text, .. }
514 if !text.is_empty() =>
515 {
516 if let Some(ChatMessagePart::Text { text: text_content }) =
517 content_parts.last_mut()
518 {
519 text_content.push_str(text);
520 } else {
521 content_parts.push(ChatMessagePart::Text {
522 text: text.to_string(),
523 });
524 }
525 }
526 MessageContent::Image(image) if model.supports_vision() => {
527 content_parts.push(ChatMessagePart::Image {
528 image_url: ImageUrl {
529 url: image.to_base64_url(),
530 },
531 });
532 }
533 _ => {}
534 }
535 }
536
537 if !content_parts.is_empty() {
538 messages.push(ChatMessage::User {
539 content: content_parts.into(),
540 });
541 }
542 }
543 Role::Assistant => {
544 let mut tool_calls = Vec::new();
545 for content in &message.content {
546 if let MessageContent::ToolUse(tool_use) = content {
547 tool_calls.push(ToolCall {
548 id: tool_use.id.to_string(),
549 content: copilot::copilot_chat::ToolCallContent::Function {
550 function: copilot::copilot_chat::FunctionContent {
551 name: tool_use.name.to_string(),
552 arguments: serde_json::to_string(&tool_use.input)?,
553 },
554 },
555 });
556 }
557 }
558
559 let text_content = {
560 let mut buffer = String::new();
561 for string in message.content.iter().filter_map(|content| match content {
562 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
563 Some(text.as_str())
564 }
565 MessageContent::ToolUse(_)
566 | MessageContent::RedactedThinking(_)
567 | MessageContent::ToolResult(_)
568 | MessageContent::Image(_) => None,
569 }) {
570 buffer.push_str(string);
571 }
572
573 buffer
574 };
575
576 messages.push(ChatMessage::Assistant {
577 content: if text_content.is_empty() {
578 ChatMessageContent::empty()
579 } else {
580 text_content.into()
581 },
582 tool_calls,
583 });
584 }
585 Role::System => messages.push(ChatMessage::System {
586 content: message.string_contents(),
587 }),
588 }
589 }
590
591 let tools = request
592 .tools
593 .iter()
594 .map(|tool| Tool::Function {
595 function: copilot::copilot_chat::Function {
596 name: tool.name.clone(),
597 description: tool.description.clone(),
598 parameters: tool.input_schema.clone(),
599 },
600 })
601 .collect::<Vec<_>>();
602
603 Ok(CopilotChatRequest {
604 intent: true,
605 n: 1,
606 stream: model.uses_streaming(),
607 temperature: 0.1,
608 model: model.id().to_string(),
609 messages,
610 tools,
611 tool_choice: request.tool_choice.map(|choice| match choice {
612 LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
613 LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
614 LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
615 }),
616 })
617}
618
619struct ConfigurationView {
620 copilot_status: Option<copilot::Status>,
621 state: Entity<State>,
622 _subscription: Option<Subscription>,
623}
624
625impl ConfigurationView {
626 pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
627 let copilot = Copilot::global(cx);
628
629 Self {
630 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
631 state,
632 _subscription: copilot.as_ref().map(|copilot| {
633 cx.observe(copilot, |this, model, cx| {
634 this.copilot_status = Some(model.read(cx).status());
635 cx.notify();
636 })
637 }),
638 }
639 }
640}
641
642impl Render for ConfigurationView {
643 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
644 if self.state.read(cx).is_authenticated(cx) {
645 h_flex()
646 .mt_1()
647 .p_1()
648 .justify_between()
649 .rounded_md()
650 .border_1()
651 .border_color(cx.theme().colors().border)
652 .bg(cx.theme().colors().background)
653 .child(
654 h_flex()
655 .gap_1()
656 .child(Icon::new(IconName::Check).color(Color::Success))
657 .child(Label::new("Authorized")),
658 )
659 .child(
660 Button::new("sign_out", "Sign Out")
661 .label_size(LabelSize::Small)
662 .on_click(|_, window, cx| {
663 window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
664 }),
665 )
666 } else {
667 let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
668 "arrow-circle",
669 Animation::new(Duration::from_secs(4)).repeat(),
670 |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
671 );
672
673 const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
674
675 match &self.copilot_status {
676 Some(status) => match status {
677 Status::Starting { task: _ } => h_flex()
678 .gap_2()
679 .child(loading_icon)
680 .child(Label::new("Starting Copilot…")),
681 Status::SigningIn { prompt: _ }
682 | Status::SignedOut {
683 awaiting_signing_in: true,
684 } => h_flex()
685 .gap_2()
686 .child(loading_icon)
687 .child(Label::new("Signing into Copilot…")),
688 Status::Error(_) => {
689 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
690 v_flex()
691 .gap_6()
692 .child(Label::new(LABEL))
693 .child(svg().size_8().path(IconName::CopilotError.path()))
694 }
695 _ => {
696 const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
697
698 v_flex().gap_2().child(Label::new(LABEL)).child(
699 Button::new("sign_in", "Sign in to use GitHub Copilot")
700 .icon_color(Color::Muted)
701 .icon(IconName::Github)
702 .icon_position(IconPosition::Start)
703 .icon_size(IconSize::Medium)
704 .full_width()
705 .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
706 )
707 }
708 },
709 None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
710 }
711 }
712 }
713}