1use crate::{
2 assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
3 LanguageModelRequest,
4};
5use anyhow::{anyhow, Result};
6use client::{proto, Client};
7use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
8use gpui::{AnyView, AppContext, Task};
9use std::{future, sync::Arc};
10use ui::prelude::*;
11
/// Completion provider that serves zed.dev language models over the [`Client`] RPC connection.
pub struct ZedDotDevCompletionProvider {
    /// Connection used for authentication and for the token-counting / completion RPCs.
    client: Arc<Client>,
    /// Model returned by `default_model()`; replaced by `update()`.
    default_model: ZedDotDevModel,
    /// Settings version this provider was created with or last updated with.
    settings_version: usize,
    /// Most recently observed client connection status; `is_authenticated()` reads it.
    status: client::Status,
    /// Task spawned in `new()` that copies client status changes onto the global
    /// `CompletionProvider::ZedDotDev`; stored only so it lives as long as the provider.
    _maintain_client_status: Task<()>,
}
19
20impl ZedDotDevCompletionProvider {
21 pub fn new(
22 default_model: ZedDotDevModel,
23 client: Arc<Client>,
24 settings_version: usize,
25 cx: &mut AppContext,
26 ) -> Self {
27 let mut status_rx = client.status();
28 let status = *status_rx.borrow();
29 let maintain_client_status = cx.spawn(|mut cx| async move {
30 while let Some(status) = status_rx.next().await {
31 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
32 if let CompletionProvider::ZedDotDev(provider) = provider {
33 provider.status = status;
34 } else {
35 unreachable!()
36 }
37 });
38 }
39 });
40 Self {
41 client,
42 default_model,
43 settings_version,
44 status,
45 _maintain_client_status: maintain_client_status,
46 }
47 }
48
49 pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
50 self.default_model = default_model;
51 self.settings_version = settings_version;
52 }
53
54 pub fn settings_version(&self) -> usize {
55 self.settings_version
56 }
57
58 pub fn default_model(&self) -> ZedDotDevModel {
59 self.default_model.clone()
60 }
61
62 pub fn is_authenticated(&self) -> bool {
63 self.status.is_connected()
64 }
65
66 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
67 let client = self.client.clone();
68 cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
69 }
70
71 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
72 cx.new_view(|_cx| AuthenticationPrompt).into()
73 }
74
75 pub fn count_tokens(
76 &self,
77 request: LanguageModelRequest,
78 cx: &AppContext,
79 ) -> BoxFuture<'static, Result<usize>> {
80 match request.model {
81 crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
82 crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
83 | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
84 | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
85 count_open_ai_tokens(request, cx.background_executor())
86 }
87 crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
88 let request = self.client.request(proto::CountTokensWithLanguageModel {
89 model,
90 messages: request
91 .messages
92 .iter()
93 .map(|message| message.to_proto())
94 .collect(),
95 });
96 async move {
97 let response = request.await?;
98 Ok(response.token_count as usize)
99 }
100 .boxed()
101 }
102 }
103 }
104
105 pub fn complete(
106 &self,
107 request: LanguageModelRequest,
108 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
109 let request = proto::CompleteWithLanguageModel {
110 model: request.model.id().to_string(),
111 messages: request
112 .messages
113 .iter()
114 .map(|message| message.to_proto())
115 .collect(),
116 stop: request.stop,
117 temperature: request.temperature,
118 };
119
120 self.client
121 .request_stream(request)
122 .map_ok(|stream| {
123 stream
124 .filter_map(|response| async move {
125 match response {
126 Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
127 Err(error) => Some(Err(error)),
128 }
129 })
130 .boxed()
131 })
132 .boxed()
133 }
134}
135
136struct AuthenticationPrompt;
137
138impl Render for AuthenticationPrompt {
139 fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
140 const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
141
142 v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
143 v_flex()
144 .gap_2()
145 .child(
146 Button::new("sign_in", "Sign in")
147 .icon_color(Color::Muted)
148 .icon(IconName::Github)
149 .icon_position(IconPosition::Start)
150 .style(ButtonStyle::Filled)
151 .full_width()
152 .on_click(|_, cx| {
153 CompletionProvider::global(cx)
154 .authenticate(cx)
155 .detach_and_log_err(cx);
156 }),
157 )
158 .child(
159 div().flex().w_full().items_center().child(
160 Label::new("Sign in to enable collaboration.")
161 .color(Color::Muted)
162 .size(LabelSize::Small),
163 ),
164 ),
165 )
166 }
167}