use std::pin::Pin;
use std::str::FromStr as _;
use std::sync::Arc;

use anyhow::{Result, anyhow};
use cloud_llm_client::CompletionIntent;
use collections::HashMap;
use copilot::copilot_chat::{
    ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
    Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
    ToolCall,
};
use copilot::{Copilot, Status};
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use futures::{FutureExt, Stream, StreamExt};
use gpui::{
    Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
    Transformation, percentage, svg,
};
use language::language_settings::all_language_settings;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
    LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
    StopReason, TokenUsage,
};
use settings::SettingsStore;
use std::time::Duration;
use ui::prelude::*;
use util::debug_panic;

use super::anthropic::count_anthropic_tokens;
use super::google::count_google_tokens;
use super::open_ai::count_open_ai_tokens;

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat");
const PROVIDER_NAME: LanguageModelProviderName =
    LanguageModelProviderName::new("GitHub Copilot Chat");

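/// Language model provider backed by GitHub Copilot Chat.
///
/// The available models are read from the global [`CopilotChat`] entity, so the
/// provider only reports models that Copilot itself advertises.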
pub struct CopilotChatLanguageModelProvider {
    state: Entity<State>,
}

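/// Holds the subscriptions that keep the provider in sync with the global
/// [`CopilotChat`] entity and the settings store.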
pub struct State {
    _copilot_chat_subscription: Option<Subscription>,
    _settings_subscription: Subscription,
}

impl State {
    fn is_authenticated(&self, cx: &App) -> bool {
        CopilotChat::global(cx)
            .map(|m| m.read(cx).is_authenticated())
            .unwrap_or(false)
    }
}

impl CopilotChatLanguageModelProvider {
    pub fn new(cx: &mut App) -> Self {
        let state = cx.new(|cx| {
            let copilot_chat_subscription = CopilotChat::global(cx)
                .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
            State {
                _copilot_chat_subscription: copilot_chat_subscription,
                _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                    if let Some(copilot_chat) = CopilotChat::global(cx) {
                        let language_settings = all_language_settings(None, cx);
                        let configuration = copilot::copilot_chat::CopilotChatConfiguration {
                            enterprise_uri: language_settings
                                .edit_predictions
                                .copilot
                                .enterprise_uri
                                .clone(),
                        };
                        copilot_chat.update(cx, |chat, cx| {
                            chat.set_configuration(configuration, cx);
                        });
                    }
                    cx.notify();
                }),
            }
        });

        Self { state }
    }

    fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
        Arc::new(CopilotChatLanguageModel {
            model,
            request_limiter: RateLimiter::new(4),
        })
    }
}

impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for CopilotChatLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::Copilot
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
        models
            .first()
            .map(|model| self.create_language_model(model.clone()))
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        // The default model should be Copilot Chat's 'base model', which is likely a relatively
        // fast model (e.g. 4o) and a sensible choice when considering premium requests.
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
            return Vec::new();
        };
        models
            .iter()
            .map(|model| self.create_language_model(model.clone()))
            .collect()
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated(cx)
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated(cx) {
            return Task::ready(Ok(()));
        }

        let Some(copilot) = Copilot::global(cx) else {
            return Task::ready(Err(anyhow!(
                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
            )
            .into()));
        };

        let err = match copilot.read(cx).status() {
            Status::Authorized => return Task::ready(Ok(())),
            Status::Disabled => anyhow!(
                "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
            ),
            Status::Error(err) => anyhow!(format!(
                "Received the following error while signing into Copilot: {err}"
            )),
            Status::Starting { task: _ } => anyhow!(
                "Copilot is still starting, please wait for Copilot to start then try again"
            ),
            Status::Unauthorized => anyhow!(
                "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
            ),
            Status::SignedOut { .. } => {
                anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
            }
            Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
        };

        Task::ready(Err(err.into()))
    }

    fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
        Task::ready(Err(anyhow!(
            "Signing out of GitHub Copilot Chat is currently not supported."
        )))
    }
}

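/// A single Copilot Chat model exposed through the [`LanguageModel`] trait.
/// Requests are funneled through a [`RateLimiter`] created with a limit of 4.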
pub struct CopilotChatLanguageModel {
    model: CopilotChatModel,
    request_limiter: RateLimiter,
}

impl LanguageModel for CopilotChatLanguageModel {
    fn id(&self) -> LanguageModelId {
        LanguageModelId::from(self.model.id().to_string())
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools()
    }

    fn supports_images(&self) -> bool {
        self.model.supports_vision()
    }

    fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
        match self.model.vendor() {
            ModelVendor::OpenAI | ModelVendor::Anthropic => {
                LanguageModelToolSchemaFormat::JsonSchema
            }
            ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
        }
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto
            | LanguageModelToolChoice::Any
            | LanguageModelToolChoice::None => self.supports_tools(),
        }
    }

    fn telemetry_id(&self) -> String {
        format!("copilot_chat/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        match self.model.vendor() {
            ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
            ModelVendor::Google => count_google_tokens(request, cx),
            ModelVendor::OpenAI => {
                let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
                count_open_ai_tokens(request, model, cx)
            }
        }
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
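        // Classify the request for the Copilot API: intents that stem directly
        // from a user action (prompts, inline assists, commit-message
        // generation) are marked user-initiated, while automatic follow-ups
        // such as tool results or file edits are not. A missing intent defaults
        // to user-initiated.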
        let is_user_initiated = request.intent.is_none_or(|intent| match intent {
            CompletionIntent::UserPrompt
            | CompletionIntent::ThreadContextSummarization
            | CompletionIntent::InlineAssist
            | CompletionIntent::TerminalInlineAssist
            | CompletionIntent::GenerateGitCommitMessage => true,

            CompletionIntent::ToolResults
            | CompletionIntent::ThreadSummarization
            | CompletionIntent::CreateFile
            | CompletionIntent::EditFile => false,
        });

        let copilot_request = match into_copilot_chat(&self.model, request) {
            Ok(request) => request,
            Err(err) => return futures::future::ready(Err(err.into())).boxed(),
        };
        let is_streaming = copilot_request.stream;

        let request_limiter = self.request_limiter.clone();
        let future = cx.spawn(async move |cx| {
            let request =
                CopilotChat::stream_completion(copilot_request, is_user_initiated, cx.clone());
            request_limiter
                .stream(async move {
                    let response = request.await?;
                    Ok(map_to_language_model_completion_events(
                        response,
                        is_streaming,
                    ))
                })
                .await
        });
        async move { Ok(future.await?.boxed()) }.boxed()
    }
}

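/// Adapts a stream of Copilot Chat [`ResponseEvent`]s into
/// [`LanguageModelCompletionEvent`]s.
///
/// Text and usage deltas are forwarded as they arrive; tool call fragments are
/// accumulated per index and only emitted once the response reports a
/// `tool_calls` finish reason.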
pub fn map_to_language_model_completion_events(
    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
    is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    #[derive(Default)]
    struct RawToolCall {
        id: String,
        name: String,
        arguments: String,
    }

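    // Tool call fragments arrive incrementally, keyed by the index Copilot
    // assigns to each call; they are buffered here until a finish reason is
    // reported.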
    struct State {
        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
        tool_calls_by_index: HashMap<usize, RawToolCall>,
    }

    futures::stream::unfold(
        State {
            events,
            tool_calls_by_index: HashMap::default(),
        },
        move |mut state| async move {
            if let Some(event) = state.events.next().await {
                match event {
                    Ok(event) => {
                        let Some(choice) = event.choices.first() else {
                            return Some((
                                vec![Err(anyhow!("Response contained no choices").into())],
                                state,
                            ));
                        };

                        let delta = if is_streaming {
                            choice.delta.as_ref()
                        } else {
                            choice.message.as_ref()
                        };

                        let Some(delta) = delta else {
                            return Some((
                                vec![Err(anyhow!("Response contained no delta").into())],
                                state,
                            ));
                        };

                        let mut events = Vec::new();
                        if let Some(content) = delta.content.clone() {
                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                        }

                        for tool_call in &delta.tool_calls {
                            let entry = state
                                .tool_calls_by_index
                                .entry(tool_call.index)
                                .or_default();

                            if let Some(tool_id) = tool_call.id.clone() {
                                entry.id = tool_id;
                            }

                            if let Some(function) = tool_call.function.as_ref() {
                                if let Some(name) = function.name.clone() {
                                    entry.name = name;
                                }

                                if let Some(arguments) = function.arguments.clone() {
                                    entry.arguments.push_str(&arguments);
                                }
                            }
                        }

                        if let Some(usage) = event.usage {
                            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                                TokenUsage {
                                    input_tokens: usage.prompt_tokens,
                                    output_tokens: usage.completion_tokens,
                                    cache_creation_input_tokens: 0,
                                    cache_read_input_tokens: 0,
                                },
                            )));
                        }

                        match choice.finish_reason.as_deref() {
                            Some("stop") => {
                                events.push(Ok(LanguageModelCompletionEvent::Stop(
                                    StopReason::EndTurn,
                                )));
                            }
                            Some("tool_calls") => {
                                events.extend(state.tool_calls_by_index.drain().map(
                                    |(_, tool_call)| {
                                        // The model can output an empty string
                                        // to indicate the absence of arguments.
                                        // When that happens, create an empty
                                        // object instead.
                                        let arguments = if tool_call.arguments.is_empty() {
                                            Ok(serde_json::Value::Object(Default::default()))
                                        } else {
                                            serde_json::Value::from_str(&tool_call.arguments)
                                        };
                                        match arguments {
                                            Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                                                LanguageModelToolUse {
                                                    id: tool_call.id.clone().into(),
                                                    name: tool_call.name.as_str().into(),
                                                    is_input_complete: true,
                                                    input,
                                                    raw_input: tool_call.arguments.clone(),
                                                },
                                            )),
                                            Err(error) => Ok(
                                                LanguageModelCompletionEvent::ToolUseJsonParseError {
                                                    id: tool_call.id.into(),
                                                    tool_name: tool_call.name.as_str().into(),
                                                    raw_input: tool_call.arguments.into(),
                                                    json_parse_error: error.to_string(),
                                                },
                                            ),
                                        }
                                    },
                                ));

                                events.push(Ok(LanguageModelCompletionEvent::Stop(
                                    StopReason::ToolUse,
                                )));
                            }
                            Some(stop_reason) => {
                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
                                events.push(Ok(LanguageModelCompletionEvent::Stop(
                                    StopReason::EndTurn,
                                )));
                            }
                            None => {}
                        }

                        return Some((events, state));
                    }
                    Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
                }
            }

            None
        },
    )
    .flat_map(futures::stream::iter)
}

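/// Translates a [`LanguageModelRequest`] into the [`CopilotChatRequest`] shape
/// expected by the Copilot Chat API, emitting separate tool, user, assistant,
/// and system messages as needed.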
fn into_copilot_chat(
    model: &copilot::copilot_chat::Model,
    request: LanguageModelRequest,
) -> Result<CopilotChatRequest> {
    let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
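    // Collapse consecutive messages that share a role into a single message, so
    // each entry sent to the API corresponds to one turn.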
    for message in request.messages {
        if let Some(last_message) = request_messages.last_mut() {
            if last_message.role == message.role {
                last_message.content.extend(message.content);
            } else {
                request_messages.push(message);
            }
        } else {
            request_messages.push(message);
        }
    }

    let mut tool_called = false;
    let mut messages: Vec<ChatMessage> = Vec::new();
    for message in request_messages {
        match message.role {
            Role::User => {
                for content in &message.content {
                    if let MessageContent::ToolResult(tool_result) = content {
                        let content = match &tool_result.content {
                            LanguageModelToolResultContent::Text(text) => text.to_string().into(),
                            LanguageModelToolResultContent::Image(image) => {
                                if model.supports_vision() {
                                    ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
                                        image_url: ImageUrl {
                                            url: image.to_base64_url(),
                                        },
                                    }])
                                } else {
                                    debug_panic!(
                                        "This should be caught at {} level",
                                        tool_result.tool_name
                                    );
                                    "[Tool responded with an image, but this model does not support vision]".to_string().into()
                                }
                            }
                        };

                        messages.push(ChatMessage::Tool {
                            tool_call_id: tool_result.tool_use_id.to_string(),
                            content,
                        });
                    }
                }

                let mut content_parts = Vec::new();
                for content in &message.content {
                    match content {
                        MessageContent::Text(text) | MessageContent::Thinking { text, .. }
                            if !text.is_empty() =>
                        {
                            if let Some(ChatMessagePart::Text { text: text_content }) =
                                content_parts.last_mut()
                            {
                                text_content.push_str(text);
                            } else {
                                content_parts.push(ChatMessagePart::Text {
                                    text: text.to_string(),
                                });
                            }
                        }
                        MessageContent::Image(image) if model.supports_vision() => {
                            content_parts.push(ChatMessagePart::Image {
                                image_url: ImageUrl {
                                    url: image.to_base64_url(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                if !content_parts.is_empty() {
                    messages.push(ChatMessage::User {
                        content: content_parts.into(),
                    });
                }
            }
            Role::Assistant => {
                let mut tool_calls = Vec::new();
                for content in &message.content {
                    if let MessageContent::ToolUse(tool_use) = content {
                        tool_called = true;
                        tool_calls.push(ToolCall {
                            id: tool_use.id.to_string(),
                            content: copilot::copilot_chat::ToolCallContent::Function {
                                function: copilot::copilot_chat::FunctionContent {
                                    name: tool_use.name.to_string(),
                                    arguments: serde_json::to_string(&tool_use.input)?,
                                },
                            },
                        });
                    }
                }

                let text_content = {
                    let mut buffer = String::new();
                    for string in message.content.iter().filter_map(|content| match content {
                        MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
                            Some(text.as_str())
                        }
                        MessageContent::ToolUse(_)
                        | MessageContent::RedactedThinking(_)
                        | MessageContent::ToolResult(_)
                        | MessageContent::Image(_) => None,
                    }) {
                        buffer.push_str(string);
                    }

                    buffer
                };

                messages.push(ChatMessage::Assistant {
                    content: if text_content.is_empty() {
                        ChatMessageContent::empty()
                    } else {
                        text_content.into()
                    },
                    tool_calls,
                });
            }
            Role::System => messages.push(ChatMessage::System {
                content: message.string_contents(),
            }),
        }
    }

    let mut tools = request
        .tools
        .iter()
        .map(|tool| Tool::Function {
            function: copilot::copilot_chat::Function {
                name: tool.name.clone(),
                description: tool.description.clone(),
                parameters: tool.input_schema.clone(),
            },
        })
        .collect::<Vec<_>>();

    // The API will return a Bad Request (with no error message) when tools
    // were used previously in the conversation but no tools are provided as
    // part of this request. Inserting a dummy tool seems to circumvent this
    // error.
    if tool_called && tools.is_empty() {
        tools.push(Tool::Function {
            function: copilot::copilot_chat::Function {
                name: "noop".to_string(),
                description: "No operation".to_string(),
                parameters: serde_json::json!({
                    "type": "object"
                }),
            },
        });
    }

    Ok(CopilotChatRequest {
        intent: true,
        n: 1,
        stream: model.uses_streaming(),
        temperature: 0.1,
        model: model.id().to_string(),
        messages,
        tools,
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
            LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
            LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
        }),
    })
}

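/// Configuration UI for the provider: shows the current Copilot sign-in status
/// and offers sign-in and sign-out actions.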
struct ConfigurationView {
    copilot_status: Option<copilot::Status>,
    state: Entity<State>,
    _subscription: Option<Subscription>,
}

impl ConfigurationView {
    pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
        let copilot = Copilot::global(cx);

        Self {
            copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
            state,
            _subscription: copilot.as_ref().map(|copilot| {
                cx.observe(copilot, |this, model, cx| {
                    this.copilot_status = Some(model.read(cx).status());
                    cx.notify();
                })
            }),
        }
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        if self.state.read(cx).is_authenticated(cx) {
            h_flex()
                .mt_1()
                .p_1()
                .justify_between()
                .rounded_md()
                .border_1()
                .border_color(cx.theme().colors().border)
                .bg(cx.theme().colors().background)
                .child(
                    h_flex()
                        .gap_1()
                        .child(Icon::new(IconName::Check).color(Color::Success))
                        .child(Label::new("Authorized")),
                )
                .child(
                    Button::new("sign_out", "Sign Out")
                        .label_size(LabelSize::Small)
                        .on_click(|_, window, cx| {
                            window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
                        }),
                )
        } else {
            let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
                "arrow-circle",
                Animation::new(Duration::from_secs(4)).repeat(),
                |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
            );

            const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";

            match &self.copilot_status {
                Some(status) => match status {
                    Status::Starting { task: _ } => h_flex()
                        .gap_2()
                        .child(loading_icon)
                        .child(Label::new("Starting Copilot…")),
                    Status::SigningIn { prompt: _ }
                    | Status::SignedOut {
                        awaiting_signing_in: true,
                    } => h_flex()
                        .gap_2()
                        .child(loading_icon)
                        .child(Label::new("Signing into Copilot…")),
                    Status::Error(_) => {
                        const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
                        v_flex()
                            .gap_6()
                            .child(Label::new(LABEL))
                            .child(svg().size_8().path(IconName::CopilotError.path()))
                    }
                    _ => {
                        const LABEL: &str = "To use Zed's agent with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";

                        v_flex().gap_2().child(Label::new(LABEL)).child(
                            Button::new("sign_in", "Sign in to use GitHub Copilot")
                                .icon_color(Color::Muted)
                                .icon(IconName::Github)
                                .icon_position(IconPosition::Start)
                                .icon_size(IconSize::Medium)
                                .full_width()
                                .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
                        )
                    }
                },
                None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
            }
        }
    }
}