1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub mod context_store;
7mod inline_assistant;
8mod model_selector;
9mod prompt_library;
10mod prompts;
11mod slash_command;
12mod streaming_diff;
13mod terminal_inline_assistant;
14
15pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
16use assistant_settings::AssistantSettings;
17use assistant_slash_command::SlashCommandRegistry;
18use client::{proto, Client};
19use command_palette_hooks::CommandPaletteFilter;
20pub use context::*;
21pub use context_store::*;
22use feature_flags::FeatureFlagAppExt;
23use fs::Fs;
24use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
25use indexed_docs::IndexedDocsRegistry;
26pub(crate) use inline_assistant::*;
27use language_model::{
28 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
29};
30pub(crate) use model_selector::*;
31pub use prompts::PromptBuilder;
32use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
33use serde::{Deserialize, Serialize};
34use settings::{update_settings_file, Settings, SettingsStore};
35use slash_command::{
36 active_command, default_command, diagnostics_command, docs_command, fetch_command,
37 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
38 tabs_command, term_command, workflow_command,
39};
40use std::sync::Arc;
41pub(crate) use streaming_diff::*;
42use util::ResultExt;
43
// Actions registered under the `assistant` namespace. Their command-palette
// visibility is controlled by `Assistant::set_enabled`, which shows or hides the
// whole `assistant` namespace in the `CommandPaletteFilter`.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugWorkflowSteps
    ]
);
62
/// Default number of context lines (20).
///
/// NOTE(review): not referenced in this file; presumably the count of surrounding
/// lines included as context by code elsewhere in the crate — confirm at use sites.
const DEFAULT_CONTEXT_LINES: usize = 20;
64
65#[derive(Clone, Default, Deserialize, PartialEq)]
66pub struct InlineAssist {
67 prompt: Option<String>,
68}
69
70impl_actions!(assistant, [InlineAssist]);
71
72#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
73pub struct MessageId(clock::Lamport);
74
75impl MessageId {
76 pub fn as_u64(self) -> u64 {
77 self.0.as_u64()
78 }
79}
80
81#[derive(Deserialize, Debug)]
82pub struct LanguageModelUsage {
83 pub prompt_tokens: u32,
84 pub completion_tokens: u32,
85 pub total_tokens: u32,
86}
87
88#[derive(Deserialize, Debug)]
89pub struct LanguageModelChoiceDelta {
90 pub index: u32,
91 pub delta: LanguageModelResponseMessage,
92 pub finish_reason: Option<String>,
93}
94
95#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
96pub enum MessageStatus {
97 Pending,
98 Done,
99 Error(SharedString),
100}
101
102impl MessageStatus {
103 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
104 match status.variant {
105 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
106 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
107 Some(proto::context_message_status::Variant::Error(error)) => {
108 MessageStatus::Error(error.message.into())
109 }
110 None => MessageStatus::Pending,
111 }
112 }
113
114 pub fn to_proto(&self) -> proto::ContextMessageStatus {
115 match self {
116 MessageStatus::Pending => proto::ContextMessageStatus {
117 variant: Some(proto::context_message_status::Variant::Pending(
118 proto::context_message_status::Pending {},
119 )),
120 },
121 MessageStatus::Done => proto::ContextMessageStatus {
122 variant: Some(proto::context_message_status::Variant::Done(
123 proto::context_message_status::Done {},
124 )),
125 },
126 MessageStatus::Error(message) => proto::ContextMessageStatus {
127 variant: Some(proto::context_message_status::Variant::Error(
128 proto::context_message_status::Error {
129 message: message.to_string(),
130 },
131 )),
132 },
133 }
134 }
135}
136
137/// The state pertaining to the Assistant.
138#[derive(Default)]
139struct Assistant {
140 /// Whether the Assistant is enabled.
141 enabled: bool,
142}
143
144impl Global for Assistant {}
145
146impl Assistant {
147 const NAMESPACE: &'static str = "assistant";
148
149 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
150 if self.enabled == enabled {
151 return;
152 }
153
154 self.enabled = enabled;
155
156 if !enabled {
157 CommandPaletteFilter::update_global(cx, |filter, _cx| {
158 filter.hide_namespace(Self::NAMESPACE);
159 });
160
161 return;
162 }
163
164 CommandPaletteFilter::update_global(cx, |filter, _cx| {
165 filter.show_namespace(Self::NAMESPACE);
166 });
167 }
168}
169
170pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
171 cx.set_global(Assistant::default());
172 AssistantSettings::register(cx);
173
174 // TODO: remove this when 0.148.0 is released.
175 if AssistantSettings::get_global(cx).using_outdated_settings_version {
176 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
177 let fs = fs.clone();
178 |content, cx| {
179 content.update_file(fs, cx);
180 }
181 });
182 }
183
184 cx.spawn(|mut cx| {
185 let client = client.clone();
186 async move {
187 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
188 let semantic_index = SemanticIndex::new(
189 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
190 Arc::new(embedding_provider),
191 &mut cx,
192 )
193 .await?;
194 cx.update(|cx| cx.set_global(semantic_index))
195 }
196 })
197 .detach();
198
199 context_store::init(&client);
200 prompt_library::init(cx);
201 init_language_model_settings(cx);
202 assistant_slash_command::init(cx);
203 assistant_panel::init(cx);
204
205 let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
206 .log_err()
207 .map(Arc::new)
208 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
209 register_slash_commands(Some(prompt_builder.clone()), cx);
210 inline_assistant::init(
211 fs.clone(),
212 prompt_builder.clone(),
213 client.telemetry().clone(),
214 cx,
215 );
216 terminal_inline_assistant::init(
217 fs.clone(),
218 prompt_builder.clone(),
219 client.telemetry().clone(),
220 cx,
221 );
222 IndexedDocsRegistry::init_global(cx);
223
224 CommandPaletteFilter::update_global(cx, |filter, _cx| {
225 filter.hide_namespace(Assistant::NAMESPACE);
226 });
227 Assistant::update_global(cx, |assistant, cx| {
228 let settings = AssistantSettings::get_global(cx);
229
230 assistant.set_enabled(settings.enabled, cx);
231 });
232 cx.observe_global::<SettingsStore>(|cx| {
233 Assistant::update_global(cx, |assistant, cx| {
234 let settings = AssistantSettings::get_global(cx);
235 assistant.set_enabled(settings.enabled, cx);
236 });
237 })
238 .detach();
239
240 prompt_builder
241}
242
243fn init_language_model_settings(cx: &mut AppContext) {
244 update_active_language_model_from_settings(cx);
245
246 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
247 .detach();
248 cx.subscribe(
249 &LanguageModelRegistry::global(cx),
250 |_, event: &language_model::Event, cx| match event {
251 language_model::Event::ProviderStateChanged
252 | language_model::Event::AddedProvider(_)
253 | language_model::Event::RemovedProvider(_) => {
254 update_active_language_model_from_settings(cx);
255 }
256 _ => {}
257 },
258 )
259 .detach();
260}
261
262fn update_active_language_model_from_settings(cx: &mut AppContext) {
263 let settings = AssistantSettings::get_global(cx);
264 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
265 let model_id = LanguageModelId::from(settings.default_model.model.clone());
266 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
267 registry.select_active_model(&provider_name, &model_id, cx);
268 });
269}
270
271fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
272 let slash_command_registry = SlashCommandRegistry::global(cx);
273 slash_command_registry.register_command(file_command::FileSlashCommand, true);
274 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
275 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
276 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
277 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
278 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
279 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
280 slash_command_registry.register_command(term_command::TermSlashCommand, true);
281 slash_command_registry.register_command(now_command::NowSlashCommand, true);
282 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
283 if let Some(prompt_builder) = prompt_builder {
284 slash_command_registry.register_command(
285 workflow_command::WorkflowSlashCommand::new(prompt_builder),
286 true,
287 );
288 }
289 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
290
291 cx.observe_flag::<docs_command::DocsSlashCommandFeatureFlag, _>({
292 let slash_command_registry = slash_command_registry.clone();
293 move |is_enabled, _cx| {
294 if is_enabled {
295 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
296 }
297 }
298 })
299 .detach();
300 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
301 let slash_command_registry = slash_command_registry.clone();
302 move |is_enabled, _cx| {
303 if is_enabled {
304 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
305 }
306 }
307 })
308 .detach();
309}
310
/// Formats a token count compactly for display.
///
/// * Counts below 1,000 are shown verbatim (`999 -> "999"`).
/// * Counts from 1,000 to 9,999 are shown in thousands, rounded half-up to one
///   decimal place, with a trailing `.0` dropped
///   (`1_000 -> "1k"`, `1_234 -> "1.2k"`, `9_950 -> "10k"`).
/// * Larger counts are shown in whole thousands, rounded half-up
///   (`10_499 -> "10k"`, `10_500 -> "11k"`).
///
/// The arithmetic never overflows, even for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            // Round to the nearest hundred tokens, i.e. tenths of a thousand.
            // `count + 50` cannot overflow inside this range.
            let tenths = (count + 50) / 100;
            let (whole, fraction) = (tenths / 10, tenths % 10);
            if fraction == 0 {
                format!("{}k", whole)
            } else {
                format!("{}.{}k", whole, fraction)
            }
        }
        _ => {
            // Round half-up to whole thousands without computing `count + 500`.
            // That sum overflows near `usize::MAX`: it panics in debug builds and
            // wraps to "0k" in release builds.
            let thousands = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", thousands)
        }
    }
}
328
/// Test-only constructor that installs `env_logger` before tests run, but only
/// when `RUST_LOG` is set to a readable value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let wants_logs = std::env::var("RUST_LOG").is_ok();
    if wants_logs {
        env_logger::init();
    }
}