1use std::pin::Pin;
2use std::str::FromStr as _;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::HashMap;
7use copilot::copilot_chat::{
8 ChatMessage, ChatMessageContent, ChatMessagePart, CopilotChat, ImageUrl,
9 Model as CopilotChatModel, ModelVendor, Request as CopilotChatRequest, ResponseEvent, Tool,
10 ToolCall,
11};
12use copilot::{Copilot, Status};
13use futures::future::BoxFuture;
14use futures::stream::BoxStream;
15use futures::{FutureExt, Stream, StreamExt};
16use gpui::{
17 Action, Animation, AnimationExt, AnyView, App, AsyncApp, Entity, Render, Subscription, Task,
18 Transformation, percentage, svg,
19};
20use language_model::{
21 AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
22 LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
23 LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
24 LanguageModelRequestMessage, LanguageModelToolChoice, LanguageModelToolResultContent,
25 LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent, RateLimiter, Role,
26 StopReason,
27};
28use settings::SettingsStore;
29use std::time::Duration;
30use ui::prelude::*;
31use util::debug_panic;
32
33use super::anthropic::count_anthropic_tokens;
34use super::google::count_google_tokens;
35use super::open_ai::count_open_ai_tokens;
36
37const PROVIDER_ID: &str = "copilot_chat";
38const PROVIDER_NAME: &str = "GitHub Copilot Chat";
39
40#[derive(Default, Clone, Debug, PartialEq)]
41pub struct CopilotChatSettings {}
42
43pub struct CopilotChatLanguageModelProvider {
44 state: Entity<State>,
45}
46
47pub struct State {
48 _copilot_chat_subscription: Option<Subscription>,
49 _settings_subscription: Subscription,
50}
51
52impl State {
53 fn is_authenticated(&self, cx: &App) -> bool {
54 CopilotChat::global(cx)
55 .map(|m| m.read(cx).is_authenticated())
56 .unwrap_or(false)
57 }
58}
59
60impl CopilotChatLanguageModelProvider {
61 pub fn new(cx: &mut App) -> Self {
62 let state = cx.new(|cx| {
63 let copilot_chat_subscription = CopilotChat::global(cx)
64 .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify()));
65 State {
66 _copilot_chat_subscription: copilot_chat_subscription,
67 _settings_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
68 cx.notify();
69 }),
70 }
71 });
72
73 Self { state }
74 }
75
76 fn create_language_model(&self, model: CopilotChatModel) -> Arc<dyn LanguageModel> {
77 Arc::new(CopilotChatLanguageModel {
78 model,
79 request_limiter: RateLimiter::new(4),
80 })
81 }
82}
83
84impl LanguageModelProviderState for CopilotChatLanguageModelProvider {
85 type ObservableEntity = State;
86
87 fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
88 Some(self.state.clone())
89 }
90}
91
92impl LanguageModelProvider for CopilotChatLanguageModelProvider {
93 fn id(&self) -> LanguageModelProviderId {
94 LanguageModelProviderId(PROVIDER_ID.into())
95 }
96
97 fn name(&self) -> LanguageModelProviderName {
98 LanguageModelProviderName(PROVIDER_NAME.into())
99 }
100
101 fn icon(&self) -> IconName {
102 IconName::Copilot
103 }
104
105 fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
106 let models = CopilotChat::global(cx).and_then(|m| m.read(cx).models())?;
107 models
108 .first()
109 .map(|model| self.create_language_model(model.clone()))
110 }
111
112 fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
113 // The default model should be Copilot Chat's 'base model', which is likely a relatively fast
114 // model (e.g. 4o) and a sensible choice when considering premium requests
115 self.default_model(cx)
116 }
117
118 fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
119 let Some(models) = CopilotChat::global(cx).and_then(|m| m.read(cx).models()) else {
120 return Vec::new();
121 };
122 models
123 .iter()
124 .map(|model| self.create_language_model(model.clone()))
125 .collect()
126 }
127
128 fn is_authenticated(&self, cx: &App) -> bool {
129 self.state.read(cx).is_authenticated(cx)
130 }
131
132 fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
133 if self.is_authenticated(cx) {
134 return Task::ready(Ok(()));
135 };
136
137 let Some(copilot) = Copilot::global(cx) else {
138 return Task::ready( Err(anyhow!(
139 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
140 ).into()));
141 };
142
143 let err = match copilot.read(cx).status() {
144 Status::Authorized => return Task::ready(Ok(())),
145 Status::Disabled => anyhow!(
146 "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."
147 ),
148 Status::Error(err) => anyhow!(format!(
149 "Received the following error while signing into Copilot: {err}"
150 )),
151 Status::Starting { task: _ } => anyhow!(
152 "Copilot is still starting, please wait for Copilot to start then try again"
153 ),
154 Status::Unauthorized => anyhow!(
155 "Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."
156 ),
157 Status::SignedOut { .. } => {
158 anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again.")
159 }
160 Status::SigningIn { prompt: _ } => anyhow!("Still signing into Copilot..."),
161 };
162
163 Task::ready(Err(err.into()))
164 }
165
166 fn configuration_view(&self, _: &mut Window, cx: &mut App) -> AnyView {
167 let state = self.state.clone();
168 cx.new(|cx| ConfigurationView::new(state, cx)).into()
169 }
170
171 fn reset_credentials(&self, _cx: &mut App) -> Task<Result<()>> {
172 Task::ready(Err(anyhow!(
173 "Signing out of GitHub Copilot Chat is currently not supported."
174 )))
175 }
176}
177
178pub struct CopilotChatLanguageModel {
179 model: CopilotChatModel,
180 request_limiter: RateLimiter,
181}
182
183impl LanguageModel for CopilotChatLanguageModel {
184 fn id(&self) -> LanguageModelId {
185 LanguageModelId::from(self.model.id().to_string())
186 }
187
188 fn name(&self) -> LanguageModelName {
189 LanguageModelName::from(self.model.display_name().to_string())
190 }
191
192 fn provider_id(&self) -> LanguageModelProviderId {
193 LanguageModelProviderId(PROVIDER_ID.into())
194 }
195
196 fn provider_name(&self) -> LanguageModelProviderName {
197 LanguageModelProviderName(PROVIDER_NAME.into())
198 }
199
200 fn supports_tools(&self) -> bool {
201 self.model.supports_tools()
202 }
203
204 fn supports_images(&self) -> bool {
205 self.model.supports_vision()
206 }
207
208 fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
209 match self.model.vendor() {
210 ModelVendor::OpenAI | ModelVendor::Anthropic => {
211 LanguageModelToolSchemaFormat::JsonSchema
212 }
213 ModelVendor::Google => LanguageModelToolSchemaFormat::JsonSchemaSubset,
214 }
215 }
216
217 fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
218 match choice {
219 LanguageModelToolChoice::Auto
220 | LanguageModelToolChoice::Any
221 | LanguageModelToolChoice::None => self.supports_tools(),
222 }
223 }
224
225 fn telemetry_id(&self) -> String {
226 format!("copilot_chat/{}", self.model.id())
227 }
228
229 fn max_token_count(&self) -> usize {
230 self.model.max_token_count()
231 }
232
233 fn count_tokens(
234 &self,
235 request: LanguageModelRequest,
236 cx: &App,
237 ) -> BoxFuture<'static, Result<usize>> {
238 match self.model.vendor() {
239 ModelVendor::Anthropic => count_anthropic_tokens(request, cx),
240 ModelVendor::Google => count_google_tokens(request, cx),
241 ModelVendor::OpenAI => {
242 let model = open_ai::Model::from_id(self.model.id()).unwrap_or_default();
243 count_open_ai_tokens(request, model, cx)
244 }
245 }
246 }
247
248 fn stream_completion(
249 &self,
250 request: LanguageModelRequest,
251 cx: &AsyncApp,
252 ) -> BoxFuture<
253 'static,
254 Result<
255 BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
256 >,
257 > {
258 if let Some(message) = request.messages.last() {
259 if message.contents_empty() {
260 const EMPTY_PROMPT_MSG: &str =
261 "Empty prompts aren't allowed. Please provide a non-empty prompt.";
262 return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed();
263 }
264
265 // Copilot Chat has a restriction that the final message must be from the user.
266 // While their API does return an error message for this, we can catch it earlier
267 // and provide a more helpful error message.
268 if !matches!(message.role, Role::User) {
269 const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt.";
270 return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed();
271 }
272 }
273
274 let copilot_request = match into_copilot_chat(&self.model, request) {
275 Ok(request) => request,
276 Err(err) => return futures::future::ready(Err(err)).boxed(),
277 };
278 let is_streaming = copilot_request.stream;
279
280 let request_limiter = self.request_limiter.clone();
281 let future = cx.spawn(async move |cx| {
282 let request = CopilotChat::stream_completion(copilot_request, cx.clone());
283 request_limiter
284 .stream(async move {
285 let response = request.await?;
286 Ok(map_to_language_model_completion_events(
287 response,
288 is_streaming,
289 ))
290 })
291 .await
292 });
293 async move { Ok(future.await?.boxed()) }.boxed()
294 }
295}
296
297pub fn map_to_language_model_completion_events(
298 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
299 is_streaming: bool,
300) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
301 #[derive(Default)]
302 struct RawToolCall {
303 id: String,
304 name: String,
305 arguments: String,
306 }
307
308 struct State {
309 events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
310 tool_calls_by_index: HashMap<usize, RawToolCall>,
311 }
312
313 futures::stream::unfold(
314 State {
315 events,
316 tool_calls_by_index: HashMap::default(),
317 },
318 move |mut state| async move {
319 if let Some(event) = state.events.next().await {
320 match event {
321 Ok(event) => {
322 let Some(choice) = event.choices.first() else {
323 return Some((
324 vec![Err(anyhow!("Response contained no choices").into())],
325 state,
326 ));
327 };
328
329 let delta = if is_streaming {
330 choice.delta.as_ref()
331 } else {
332 choice.message.as_ref()
333 };
334
335 let Some(delta) = delta else {
336 return Some((
337 vec![Err(anyhow!("Response contained no delta").into())],
338 state,
339 ));
340 };
341
342 let mut events = Vec::new();
343 if let Some(content) = delta.content.clone() {
344 events.push(Ok(LanguageModelCompletionEvent::Text(content)));
345 }
346
347 for tool_call in &delta.tool_calls {
348 let entry = state
349 .tool_calls_by_index
350 .entry(tool_call.index)
351 .or_default();
352
353 if let Some(tool_id) = tool_call.id.clone() {
354 entry.id = tool_id;
355 }
356
357 if let Some(function) = tool_call.function.as_ref() {
358 if let Some(name) = function.name.clone() {
359 entry.name = name;
360 }
361
362 if let Some(arguments) = function.arguments.clone() {
363 entry.arguments.push_str(&arguments);
364 }
365 }
366 }
367
368 match choice.finish_reason.as_deref() {
369 Some("stop") => {
370 events.push(Ok(LanguageModelCompletionEvent::Stop(
371 StopReason::EndTurn,
372 )));
373 }
374 Some("tool_calls") => {
375 events.extend(state.tool_calls_by_index.drain().map(
376 |(_, tool_call)| {
377 // The model can output an empty string
378 // to indicate the absence of arguments.
379 // When that happens, create an empty
380 // object instead.
381 let arguments = if tool_call.arguments.is_empty() {
382 Ok(serde_json::Value::Object(Default::default()))
383 } else {
384 serde_json::Value::from_str(&tool_call.arguments)
385 };
386 match arguments {
387 Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
388 LanguageModelToolUse {
389 id: tool_call.id.clone().into(),
390 name: tool_call.name.as_str().into(),
391 is_input_complete: true,
392 input,
393 raw_input: tool_call.arguments.clone(),
394 },
395 )),
396 Err(error) => {
397 Err(LanguageModelCompletionError::BadInputJson {
398 id: tool_call.id.into(),
399 tool_name: tool_call.name.as_str().into(),
400 raw_input: tool_call.arguments.into(),
401 json_parse_error: error.to_string(),
402 })
403 }
404 }
405 },
406 ));
407
408 events.push(Ok(LanguageModelCompletionEvent::Stop(
409 StopReason::ToolUse,
410 )));
411 }
412 Some(stop_reason) => {
413 log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
414 events.push(Ok(LanguageModelCompletionEvent::Stop(
415 StopReason::EndTurn,
416 )));
417 }
418 None => {}
419 }
420
421 return Some((events, state));
422 }
423 Err(err) => return Some((vec![Err(anyhow!(err).into())], state)),
424 }
425 }
426
427 None
428 },
429 )
430 .flat_map(futures::stream::iter)
431}
432
433fn into_copilot_chat(
434 model: &copilot::copilot_chat::Model,
435 request: LanguageModelRequest,
436) -> Result<CopilotChatRequest> {
437 let mut request_messages: Vec<LanguageModelRequestMessage> = Vec::new();
438 for message in request.messages {
439 if let Some(last_message) = request_messages.last_mut() {
440 if last_message.role == message.role {
441 last_message.content.extend(message.content);
442 } else {
443 request_messages.push(message);
444 }
445 } else {
446 request_messages.push(message);
447 }
448 }
449
450 let mut tool_called = false;
451 let mut messages: Vec<ChatMessage> = Vec::new();
452 for message in request_messages {
453 match message.role {
454 Role::User => {
455 for content in &message.content {
456 if let MessageContent::ToolResult(tool_result) = content {
457 let content = match &tool_result.content {
458 LanguageModelToolResultContent::Text(text) => text.to_string().into(),
459 LanguageModelToolResultContent::Image(image) => {
460 if model.supports_vision() {
461 ChatMessageContent::Multipart(vec![ChatMessagePart::Image {
462 image_url: ImageUrl {
463 url: image.to_base64_url(),
464 },
465 }])
466 } else {
467 debug_panic!(
468 "This should be caught at {} level",
469 tool_result.tool_name
470 );
471 "[Tool responded with an image, but this model does not support vision]".to_string().into()
472 }
473 }
474 };
475
476 messages.push(ChatMessage::Tool {
477 tool_call_id: tool_result.tool_use_id.to_string(),
478 content,
479 });
480 }
481 }
482
483 let mut content_parts = Vec::new();
484 for content in &message.content {
485 match content {
486 MessageContent::Text(text) | MessageContent::Thinking { text, .. }
487 if !text.is_empty() =>
488 {
489 if let Some(ChatMessagePart::Text { text: text_content }) =
490 content_parts.last_mut()
491 {
492 text_content.push_str(text);
493 } else {
494 content_parts.push(ChatMessagePart::Text {
495 text: text.to_string(),
496 });
497 }
498 }
499 MessageContent::Image(image) if model.supports_vision() => {
500 content_parts.push(ChatMessagePart::Image {
501 image_url: ImageUrl {
502 url: image.to_base64_url(),
503 },
504 });
505 }
506 _ => {}
507 }
508 }
509
510 if !content_parts.is_empty() {
511 messages.push(ChatMessage::User {
512 content: content_parts.into(),
513 });
514 }
515 }
516 Role::Assistant => {
517 let mut tool_calls = Vec::new();
518 for content in &message.content {
519 if let MessageContent::ToolUse(tool_use) = content {
520 tool_called = true;
521 tool_calls.push(ToolCall {
522 id: tool_use.id.to_string(),
523 content: copilot::copilot_chat::ToolCallContent::Function {
524 function: copilot::copilot_chat::FunctionContent {
525 name: tool_use.name.to_string(),
526 arguments: serde_json::to_string(&tool_use.input)?,
527 },
528 },
529 });
530 }
531 }
532
533 let text_content = {
534 let mut buffer = String::new();
535 for string in message.content.iter().filter_map(|content| match content {
536 MessageContent::Text(text) | MessageContent::Thinking { text, .. } => {
537 Some(text.as_str())
538 }
539 MessageContent::ToolUse(_)
540 | MessageContent::RedactedThinking(_)
541 | MessageContent::ToolResult(_)
542 | MessageContent::Image(_) => None,
543 }) {
544 buffer.push_str(string);
545 }
546
547 buffer
548 };
549
550 messages.push(ChatMessage::Assistant {
551 content: if text_content.is_empty() {
552 ChatMessageContent::empty()
553 } else {
554 text_content.into()
555 },
556 tool_calls,
557 });
558 }
559 Role::System => messages.push(ChatMessage::System {
560 content: message.string_contents(),
561 }),
562 }
563 }
564
565 let mut tools = request
566 .tools
567 .iter()
568 .map(|tool| Tool::Function {
569 function: copilot::copilot_chat::Function {
570 name: tool.name.clone(),
571 description: tool.description.clone(),
572 parameters: tool.input_schema.clone(),
573 },
574 })
575 .collect::<Vec<_>>();
576
577 // The API will return a Bad Request (with no error message) when tools
578 // were used previously in the conversation but no tools are provided as
579 // part of this request. Inserting a dummy tool seems to circumvent this
580 // error.
581 if tool_called && tools.is_empty() {
582 tools.push(Tool::Function {
583 function: copilot::copilot_chat::Function {
584 name: "noop".to_string(),
585 description: "No operation".to_string(),
586 parameters: serde_json::json!({
587 "type": "object"
588 }),
589 },
590 });
591 }
592
593 Ok(CopilotChatRequest {
594 intent: true,
595 n: 1,
596 stream: model.uses_streaming(),
597 temperature: 0.1,
598 model: model.id().to_string(),
599 messages,
600 tools,
601 tool_choice: request.tool_choice.map(|choice| match choice {
602 LanguageModelToolChoice::Auto => copilot::copilot_chat::ToolChoice::Auto,
603 LanguageModelToolChoice::Any => copilot::copilot_chat::ToolChoice::Any,
604 LanguageModelToolChoice::None => copilot::copilot_chat::ToolChoice::None,
605 }),
606 })
607}
608
609struct ConfigurationView {
610 copilot_status: Option<copilot::Status>,
611 state: Entity<State>,
612 _subscription: Option<Subscription>,
613}
614
615impl ConfigurationView {
616 pub fn new(state: Entity<State>, cx: &mut Context<Self>) -> Self {
617 let copilot = Copilot::global(cx);
618
619 Self {
620 copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()),
621 state,
622 _subscription: copilot.as_ref().map(|copilot| {
623 cx.observe(copilot, |this, model, cx| {
624 this.copilot_status = Some(model.read(cx).status());
625 cx.notify();
626 })
627 }),
628 }
629 }
630}
631
632impl Render for ConfigurationView {
633 fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
634 if self.state.read(cx).is_authenticated(cx) {
635 h_flex()
636 .mt_1()
637 .p_1()
638 .justify_between()
639 .rounded_md()
640 .border_1()
641 .border_color(cx.theme().colors().border)
642 .bg(cx.theme().colors().background)
643 .child(
644 h_flex()
645 .gap_1()
646 .child(Icon::new(IconName::Check).color(Color::Success))
647 .child(Label::new("Authorized")),
648 )
649 .child(
650 Button::new("sign_out", "Sign Out")
651 .label_size(LabelSize::Small)
652 .on_click(|_, window, cx| {
653 window.dispatch_action(copilot::SignOut.boxed_clone(), cx);
654 }),
655 )
656 } else {
657 let loading_icon = Icon::new(IconName::ArrowCircle).with_animation(
658 "arrow-circle",
659 Animation::new(Duration::from_secs(4)).repeat(),
660 |icon, delta| icon.transform(Transformation::rotate(percentage(delta))),
661 );
662
663 const ERROR_LABEL: &str = "Copilot Chat requires an active GitHub Copilot subscription. Please ensure Copilot is configured and try again, or use a different Assistant provider.";
664
665 match &self.copilot_status {
666 Some(status) => match status {
667 Status::Starting { task: _ } => h_flex()
668 .gap_2()
669 .child(loading_icon)
670 .child(Label::new("Starting Copilot…")),
671 Status::SigningIn { prompt: _ }
672 | Status::SignedOut {
673 awaiting_signing_in: true,
674 } => h_flex()
675 .gap_2()
676 .child(loading_icon)
677 .child(Label::new("Signing into Copilot…")),
678 Status::Error(_) => {
679 const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot.";
680 v_flex()
681 .gap_6()
682 .child(Label::new(LABEL))
683 .child(svg().size_8().path(IconName::CopilotError.path()))
684 }
685 _ => {
686 const LABEL: &str = "To use Zed's assistant with GitHub Copilot, you need to be logged in to GitHub. Note that your GitHub account must have an active Copilot Chat subscription.";
687 v_flex().gap_2().child(Label::new(LABEL)).child(
688 Button::new("sign_in", "Sign in to use GitHub Copilot")
689 .icon_color(Color::Muted)
690 .icon(IconName::Github)
691 .icon_position(IconPosition::Start)
692 .icon_size(IconSize::Medium)
693 .full_width()
694 .on_click(|_, window, cx| copilot::initiate_sign_in(window, cx)),
695 )
696 }
697 },
698 None => v_flex().gap_6().child(Label::new(ERROR_LABEL)),
699 }
700 }
701 }
702}