1#![cfg_attr(target_os = "windows", allow(unused, dead_code))]
2
3pub mod assistant_panel;
4pub mod assistant_settings;
5mod context;
6pub(crate) mod context_inspector;
7pub mod context_store;
8mod inline_assistant;
9mod model_selector;
10mod prompt_library;
11mod prompts;
12mod slash_command;
13mod streaming_diff;
14mod terminal_inline_assistant;
15
16pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
17use assistant_settings::AssistantSettings;
18use assistant_slash_command::SlashCommandRegistry;
19use client::{proto, Client};
20use command_palette_hooks::CommandPaletteFilter;
21pub use context::*;
22pub use context_store::*;
23use feature_flags::FeatureFlagAppExt;
24use fs::Fs;
25use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
26use indexed_docs::IndexedDocsRegistry;
27pub(crate) use inline_assistant::*;
28use language_model::{
29 LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
30};
31pub(crate) use model_selector::*;
32pub use prompts::PromptBuilder;
33use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
34use serde::{Deserialize, Serialize};
35use settings::{update_settings_file, Settings, SettingsStore};
36use slash_command::{
37 active_command, default_command, diagnostics_command, docs_command, fetch_command,
38 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
39 tabs_command, term_command, workflow_command,
40};
41use std::sync::Arc;
42pub(crate) use streaming_diff::*;
43use util::ResultExt;
44
// Actions registered under the `assistant` namespace.
//
// NOTE(review): `Assistant::NAMESPACE` below is the same `"assistant"` string, so these are
// presumably the actions that `Assistant::set_enabled` shows/hides through
// `CommandPaletteFilter` — confirm the filter matches on the action namespace.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        InsertActivePrompt,
        ShowConfiguration,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugWorkflowSteps
    ]
);
63
/// Default number of context lines (20).
///
/// NOTE(review): not referenced in this file; presumably consumed elsewhere in the crate
/// (e.g. when gathering surrounding source lines for a prompt) — confirm at the use site.
const DEFAULT_CONTEXT_LINES: usize = 20;
65
/// An action, registered in the `assistant` namespace via `impl_actions!`, that (judging by
/// its name) invokes the inline assistant.
///
/// Because it derives `Deserialize`, it can be constructed from data such as a keymap entry.
#[derive(Clone, Default, Deserialize, PartialEq)]
pub struct InlineAssist {
    /// Optional prompt text carried by the action.
    /// NOTE(review): presumably used to pre-fill the inline assistant's prompt — confirm in
    /// `inline_assistant`.
    prompt: Option<String>,
}

impl_actions!(assistant, [InlineAssist]);
72
73#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
74pub struct MessageId(clock::Lamport);
75
76impl MessageId {
77 pub fn as_u64(self) -> u64 {
78 self.0.as_u64()
79 }
80}
81
/// Token usage counts deserialized from a language-model response.
///
/// NOTE(review): field names look like an OpenAI-style `usage` object — confirm against the
/// provider payload this is deserialized from.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total token count as reported by the provider (not computed here).
    pub total_tokens: u32,
}
88
/// A single choice delta deserialized from a language-model response.
///
/// NOTE(review): presumably one entry of a streamed `choices` array — confirm against the
/// provider's streaming format.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// Incremental message content for this choice.
    pub delta: LanguageModelResponseMessage,
    /// Provider-supplied reason the choice finished, if it has finished.
    pub finish_reason: Option<String>,
}
95
/// Status of a message, serializable and convertible to/from
/// `proto::ContextMessageStatus` (see `from_proto` / `to_proto`).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// Not yet finished. Also used by `from_proto` when the proto carries no variant.
    Pending,
    /// Finished successfully.
    Done,
    /// Finished with an error; carries the error message.
    Error(SharedString),
}
102
103impl MessageStatus {
104 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
105 match status.variant {
106 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
107 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
108 Some(proto::context_message_status::Variant::Error(error)) => {
109 MessageStatus::Error(error.message.into())
110 }
111 None => MessageStatus::Pending,
112 }
113 }
114
115 pub fn to_proto(&self) -> proto::ContextMessageStatus {
116 match self {
117 MessageStatus::Pending => proto::ContextMessageStatus {
118 variant: Some(proto::context_message_status::Variant::Pending(
119 proto::context_message_status::Pending {},
120 )),
121 },
122 MessageStatus::Done => proto::ContextMessageStatus {
123 variant: Some(proto::context_message_status::Variant::Done(
124 proto::context_message_status::Done {},
125 )),
126 },
127 MessageStatus::Error(message) => proto::ContextMessageStatus {
128 variant: Some(proto::context_message_status::Variant::Error(
129 proto::context_message_status::Error {
130 message: message.to_string(),
131 },
132 )),
133 },
134 }
135 }
136}
137
138/// The state pertaining to the Assistant.
139#[derive(Default)]
140struct Assistant {
141 /// Whether the Assistant is enabled.
142 enabled: bool,
143}
144
145impl Global for Assistant {}
146
147impl Assistant {
148 const NAMESPACE: &'static str = "assistant";
149
150 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
151 if self.enabled == enabled {
152 return;
153 }
154
155 self.enabled = enabled;
156
157 if !enabled {
158 CommandPaletteFilter::update_global(cx, |filter, _cx| {
159 filter.hide_namespace(Self::NAMESPACE);
160 });
161
162 return;
163 }
164
165 CommandPaletteFilter::update_global(cx, |filter, _cx| {
166 filter.show_namespace(Self::NAMESPACE);
167 });
168 }
169}
170
171pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) -> Arc<PromptBuilder> {
172 cx.set_global(Assistant::default());
173 AssistantSettings::register(cx);
174
175 // TODO: remove this when 0.148.0 is released.
176 if AssistantSettings::get_global(cx).using_outdated_settings_version {
177 update_settings_file::<AssistantSettings>(fs.clone(), cx, {
178 let fs = fs.clone();
179 |content, cx| {
180 content.update_file(fs, cx);
181 }
182 });
183 }
184
185 cx.spawn(|mut cx| {
186 let client = client.clone();
187 async move {
188 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
189 let semantic_index = SemanticIndex::new(
190 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
191 Arc::new(embedding_provider),
192 &mut cx,
193 )
194 .await?;
195 cx.update(|cx| cx.set_global(semantic_index))
196 }
197 })
198 .detach();
199
200 context_store::init(&client);
201 prompt_library::init(cx);
202 init_language_model_settings(cx);
203 assistant_slash_command::init(cx);
204 assistant_panel::init(cx);
205
206 let prompt_builder = prompts::PromptBuilder::new(Some((fs.clone(), cx)))
207 .log_err()
208 .map(Arc::new)
209 .unwrap_or_else(|| Arc::new(prompts::PromptBuilder::new(None).unwrap()));
210 register_slash_commands(Some(prompt_builder.clone()), cx);
211 inline_assistant::init(
212 fs.clone(),
213 prompt_builder.clone(),
214 client.telemetry().clone(),
215 cx,
216 );
217 terminal_inline_assistant::init(
218 fs.clone(),
219 prompt_builder.clone(),
220 client.telemetry().clone(),
221 cx,
222 );
223 IndexedDocsRegistry::init_global(cx);
224
225 CommandPaletteFilter::update_global(cx, |filter, _cx| {
226 filter.hide_namespace(Assistant::NAMESPACE);
227 });
228 Assistant::update_global(cx, |assistant, cx| {
229 let settings = AssistantSettings::get_global(cx);
230
231 assistant.set_enabled(settings.enabled, cx);
232 });
233 cx.observe_global::<SettingsStore>(|cx| {
234 Assistant::update_global(cx, |assistant, cx| {
235 let settings = AssistantSettings::get_global(cx);
236 assistant.set_enabled(settings.enabled, cx);
237 });
238 })
239 .detach();
240
241 prompt_builder
242}
243
244fn init_language_model_settings(cx: &mut AppContext) {
245 update_active_language_model_from_settings(cx);
246
247 cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
248 .detach();
249 cx.subscribe(
250 &LanguageModelRegistry::global(cx),
251 |_, event: &language_model::Event, cx| match event {
252 language_model::Event::ProviderStateChanged
253 | language_model::Event::AddedProvider(_)
254 | language_model::Event::RemovedProvider(_) => {
255 update_active_language_model_from_settings(cx);
256 }
257 _ => {}
258 },
259 )
260 .detach();
261}
262
263fn update_active_language_model_from_settings(cx: &mut AppContext) {
264 let settings = AssistantSettings::get_global(cx);
265 let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
266 let model_id = LanguageModelId::from(settings.default_model.model.clone());
267 LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
268 registry.select_active_model(&provider_name, &model_id, cx);
269 });
270}
271
272fn register_slash_commands(prompt_builder: Option<Arc<PromptBuilder>>, cx: &mut AppContext) {
273 let slash_command_registry = SlashCommandRegistry::global(cx);
274 slash_command_registry.register_command(file_command::FileSlashCommand, true);
275 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
276 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
277 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
278 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
279 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
280 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
281 slash_command_registry.register_command(term_command::TermSlashCommand, true);
282 slash_command_registry.register_command(now_command::NowSlashCommand, true);
283 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
284 if let Some(prompt_builder) = prompt_builder {
285 slash_command_registry.register_command(
286 workflow_command::WorkflowSlashCommand::new(prompt_builder),
287 true,
288 );
289 }
290 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
291
292 cx.observe_flag::<docs_command::DocsSlashCommandFeatureFlag, _>({
293 let slash_command_registry = slash_command_registry.clone();
294 move |is_enabled, _cx| {
295 if is_enabled {
296 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
297 }
298 }
299 })
300 .detach();
301 cx.observe_flag::<search_command::SearchSlashCommandFeatureFlag, _>({
302 let slash_command_registry = slash_command_registry.clone();
303 move |is_enabled, _cx| {
304 if is_enabled {
305 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
306 }
307 }
308 })
309 .detach();
310}
311
/// Formats a token count compactly for display.
///
/// - Below 1,000 the count is shown verbatim (`999` → `"999"`).
/// - From 1,000 to 9,999 it is shown in thousands with one decimal, rounded half-up to the
///   nearest hundred; a zero decimal is dropped (`1000` → `"1k"`, `1500` → `"1.5k"`,
///   `1950` → `"2k"`, `9950` → `"10k"`).
/// - From 10,000 up it is rounded half-up to the nearest thousand (`10500` → `"11k"`).
///
/// Never overflows or panics, including for `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            let thousands = count / 1000;
            // Remainder rounded half-up to the nearest hundred (0..=10); cannot overflow here.
            let hundreds = (count % 1000 + 50) / 100;
            match hundreds {
                0 => format!("{}k", thousands),
                // Rounding carried into the next thousand, e.g. 1950 -> "2k".
                10 => format!("{}k", thousands + 1),
                _ => format!("{}.{}k", thousands, hundreds),
            }
        }
        _ => {
            // Equivalent to `(count + 500) / 1000`, but that sum overflows (and panics in
            // debug builds) for counts within 500 of `usize::MAX`.
            let rounded = count / 1000 + usize::from(count % 1000 >= 500);
            format!("{}k", rounded)
        }
    }
}
329
/// Test-only constructor that installs `env_logger`, but only when `RUST_LOG` is set to a
/// valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}