use crate::{
    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
    LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
use std::{future, sync::Arc};
use ui::prelude::*;

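/// Completion provider backed by zed.dev: requests are proxied through the
/// collaboration server over the signed-in [`Client`] connection, so being
/// "authenticated" here simply means that connection is established.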
pub struct ZedDotDevCompletionProvider {
    client: Arc<Client>,
    default_model: ZedDotDevModel,
    settings_version: usize,
    status: client::Status,
    _maintain_client_status: Task<()>,
}

impl ZedDotDevCompletionProvider {
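    /// Creates the provider and spawns a background task that mirrors the
    /// client's connection status into the global [`CompletionProvider`],
    /// keeping `is_authenticated` accurate as the user signs in or out.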
    pub fn new(
        default_model: ZedDotDevModel,
        client: Arc<Client>,
        settings_version: usize,
        cx: &mut AppContext,
    ) -> Self {
        let mut status_rx = client.status();
        let status = *status_rx.borrow();
        let maintain_client_status = cx.spawn(|mut cx| async move {
            while let Some(status) = status_rx.next().await {
                let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
                    if let CompletionProvider::ZedDotDev(provider) = provider {
                        provider.status = status;
                    } else {
                        unreachable!()
                    }
                });
            }
        });
        Self {
            client,
            default_model,
            settings_version,
            status,
            _maintain_client_status: maintain_client_status,
        }
    }

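    /// Applies a settings change, replacing the default model and recording
    /// the new settings version.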
    pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
        self.default_model = default_model;
        self.settings_version = settings_version;
    }

    pub fn settings_version(&self) -> usize {
        self.settings_version
    }

    pub fn default_model(&self) -> ZedDotDevModel {
        self.default_model.clone()
    }

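    /// For this provider, "authenticated" means the client is currently
    /// connected to the collaboration server.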
    pub fn is_authenticated(&self) -> bool {
        self.status.is_connected()
    }

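    /// Starts the sign-in and connection flow for the zed.dev client.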
    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
        let client = self.client.clone();
        cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
    }

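    /// Returns the sign-in view shown in the assistant panel while the
    /// provider is unauthenticated.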
    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        cx.new_view(|_cx| AuthenticationPrompt).into()
    }

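    /// Counts the tokens in `request`. The GPT-based zed.dev models are
    /// tokenized locally with the OpenAI tokenizer, the Claude 3 models reuse
    /// that tokenizer as an approximation, and custom models are counted
    /// server-side via the `CountTokensWithLanguageModel` RPC. A direct
    /// OpenAI model is rejected as invalid for this provider.
    ///
    /// Illustrative usage, assuming a `provider` and `request` are in scope
    /// within an async context (a sketch, not compiled here):
    ///
    /// ```ignore
    /// let token_count = provider.count_tokens(request, cx).await?;
    /// ```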
    pub fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        match request.model {
            LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
            LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
                count_open_ai_tokens(request, cx.background_executor())
            }
            LanguageModel::ZedDotDev(
                ZedDotDevModel::Claude3Opus
                | ZedDotDevModel::Claude3Sonnet
                | ZedDotDevModel::Claude3Haiku,
            ) => {
                // No tokenizer is readily available for Claude 3, so for now use
                // OpenAI's tokenizer as an approximation.
                count_open_ai_tokens(request, cx.background_executor())
            }
            LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
                let request = self.client.request(proto::CountTokensWithLanguageModel {
                    model,
                    messages: request
                        .messages
                        .iter()
                        .map(|message| message.to_proto())
                        .collect(),
                });
                async move {
                    let response = request.await?;
                    Ok(response.token_count as usize)
                }
                .boxed()
            }
        }
    }

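    /// Streams a completion for `request` through the collaboration server's
    /// `CompleteWithLanguageModel` RPC, yielding only the text deltas from
    /// each response chunk and skipping chunks without content.
    ///
    /// Illustrative usage, assuming a `provider` and `request` in scope and
    /// `futures::StreamExt` imported (a sketch, not compiled here):
    ///
    /// ```ignore
    /// let mut chunks = provider.complete(request).await?;
    /// while let Some(chunk) = chunks.next().await {
    ///     print!("{}", chunk?);
    /// }
    /// ```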
    pub fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = proto::CompleteWithLanguageModel {
            model: request.model.id().to_string(),
            messages: request
                .messages
                .iter()
                .map(|message| message.to_proto())
                .collect(),
            stop: request.stop,
            temperature: request.temperature,
            tools: Vec::new(),
            tool_choice: None,
        };

        self.client
            .request_stream(request)
            .map_ok(|stream| {
                stream
                    .filter_map(|response| async move {
                        match response {
                            Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
                            Err(error) => Some(Err(error)),
                        }
                    })
                    .boxed()
            })
            .boxed()
    }
}

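/// Sign-in prompt rendered while the zed.dev client is disconnected; its
/// button kicks off authentication via the global [`CompletionProvider`].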
struct AuthenticationPrompt;

impl Render for AuthenticationPrompt {
    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";

        v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
            v_flex()
                .gap_2()
                .child(
                    Button::new("sign_in", "Sign in")
                        .icon_color(Color::Muted)
                        .icon(IconName::Github)
                        .icon_position(IconPosition::Start)
                        .style(ButtonStyle::Filled)
                        .full_width()
                        .on_click(|_, cx| {
                            CompletionProvider::global(cx)
                                .authenticate(cx)
                                .detach_and_log_err(cx);
                        }),
                )
                .child(
                    div().flex().w_full().items_center().child(
                        Label::new("Sign in to enable collaboration.")
                            .color(Color::Muted)
                            .size(LabelSize::Small),
                    ),
                ),
        )
    }
}