1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context;
5pub mod context_store;
6mod inline_assistant;
7mod model_selector;
8mod prompt_library;
9mod prompts;
10mod slash_command;
11mod streaming_diff;
12mod terminal_inline_assistant;
13
14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub use completion_provider::*;
20pub use context::*;
21pub use context_store::*;
22use fs::Fs;
23use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
24use indexed_docs::IndexedDocsRegistry;
25pub(crate) use inline_assistant::*;
26pub(crate) use model_selector::*;
27use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
28use serde::{Deserialize, Serialize};
29use settings::{Settings, SettingsStore};
30use slash_command::{
31 active_command, default_command, diagnostics_command, docs_command, fetch_command,
32 file_command, now_command, project_command, prompt_command, search_command, symbols_command,
33 tabs_command, term_command,
34};
35use std::{
36 fmt::{self, Display},
37 sync::Arc,
38};
39pub(crate) use streaming_diff::*;
40
// Declares the `assistant` action namespace. Each identifier below becomes a
// unit-struct action that can be bound to keystrokes or dispatched from the
// command palette (subject to the `Assistant::NAMESPACE` filter set up in
// `init`).
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        InsertIntoEditor,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        DeployHistory,
        DeployPromptLibrary,
        ConfirmCommand,
        ToggleModelSelector,
        DebugEditSteps
    ]
);
60
/// Identifier for a message within a context, backed by a `clock::Lamport`
/// timestamp (see `as_u64` for the scalar form used in wire formats).
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct MessageId(clock::Lamport);
63
64impl MessageId {
65 pub fn as_u64(self) -> u64 {
66 self.0.as_u64()
67 }
68}
69
/// Who authored a message in a conversation.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`), matching
/// the conventional chat-completion role strings.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
77
78impl Role {
79 pub fn from_proto(role: i32) -> Role {
80 match proto::LanguageModelRole::from_i32(role) {
81 Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
82 Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
83 Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
84 Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
85 None => Role::User,
86 }
87 }
88
89 pub fn to_proto(&self) -> proto::LanguageModelRole {
90 match self {
91 Role::User => proto::LanguageModelRole::LanguageModelUser,
92 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
93 Role::System => proto::LanguageModelRole::LanguageModelSystem,
94 }
95 }
96
97 pub fn cycle(self) -> Role {
98 match self {
99 Role::User => Role::Assistant,
100 Role::Assistant => Role::System,
101 Role::System => Role::User,
102 }
103 }
104}
105
106impl Display for Role {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
108 match self {
109 Role::User => write!(f, "user"),
110 Role::Assistant => write!(f, "assistant"),
111 Role::System => write!(f, "system"),
112 }
113 }
114}
115
/// A configured language model, one variant per supported provider.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model hosted by Zed's cloud service.
    Cloud(CloudModel),
    /// A model served by the OpenAI API.
    OpenAi(OpenAiModel),
    /// A model served by the Anthropic API.
    Anthropic(AnthropicModel),
    /// A locally running Ollama model.
    Ollama(OllamaModel),
}
123
124impl Default for LanguageModel {
125 fn default() -> Self {
126 LanguageModel::Cloud(CloudModel::default())
127 }
128}
129
130impl LanguageModel {
131 pub fn telemetry_id(&self) -> String {
132 match self {
133 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
134 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
135 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
136 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
137 }
138 }
139
140 pub fn display_name(&self) -> String {
141 match self {
142 LanguageModel::OpenAi(model) => model.display_name().into(),
143 LanguageModel::Anthropic(model) => model.display_name().into(),
144 LanguageModel::Cloud(model) => model.display_name().into(),
145 LanguageModel::Ollama(model) => model.display_name().into(),
146 }
147 }
148
149 pub fn max_token_count(&self) -> usize {
150 match self {
151 LanguageModel::OpenAi(model) => model.max_token_count(),
152 LanguageModel::Anthropic(model) => model.max_token_count(),
153 LanguageModel::Cloud(model) => model.max_token_count(),
154 LanguageModel::Ollama(model) => model.max_token_count(),
155 }
156 }
157
158 pub fn id(&self) -> &str {
159 match self {
160 LanguageModel::OpenAi(model) => model.id(),
161 LanguageModel::Anthropic(model) => model.id(),
162 LanguageModel::Cloud(model) => model.id(),
163 LanguageModel::Ollama(model) => model.id(),
164 }
165 }
166}
167
/// A single message to include in a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    // Author of the message (user/assistant/system).
    pub role: Role,
    // Plain-text message body.
    pub content: String,
}
173
174impl LanguageModelRequestMessage {
175 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
176 proto::LanguageModelRequestMessage {
177 role: self.role.to_proto() as i32,
178 content: self.content.clone(),
179 tool_calls: Vec::new(),
180 tool_call_id: None,
181 }
182 }
183}
184
/// A completion request: the model to use, the conversation so far, and
/// sampling controls.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
    // Model the request is addressed to.
    pub model: LanguageModel,
    // Conversation history, oldest first.
    pub messages: Vec<LanguageModelRequestMessage>,
    // Stop sequences that end generation when emitted.
    pub stop: Vec<String>,
    // Sampling temperature passed through to the provider.
    pub temperature: f32,
}
192
193impl LanguageModelRequest {
194 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
195 proto::CompleteWithLanguageModel {
196 model: self.model.id().to_string(),
197 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
198 stop: self.stop.clone(),
199 temperature: self.temperature,
200 tool_choice: None,
201 tools: Vec::new(),
202 }
203 }
204
205 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
206 pub fn preprocess(&mut self) {
207 match &self.model {
208 LanguageModel::OpenAi(_) => {}
209 LanguageModel::Anthropic(_) => {}
210 LanguageModel::Ollama(_) => {}
211 LanguageModel::Cloud(model) => match model {
212 CloudModel::Claude3Opus
213 | CloudModel::Claude3Sonnet
214 | CloudModel::Claude3Haiku
215 | CloudModel::Claude3_5Sonnet => {
216 preprocess_anthropic_request(self);
217 }
218 _ => {}
219 },
220 }
221 }
222}
223
/// A (possibly partial) message received from a model.
///
/// Both fields are optional because this type also carries streamed deltas
/// (see `LanguageModelChoiceDelta`), which may include only part of a message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
229
/// Token accounting reported by a provider for a single request.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    // Tokens consumed by the request/prompt side.
    pub prompt_tokens: u32,
    // Tokens generated in the completion.
    pub completion_tokens: u32,
    // Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
236
/// One streamed chunk of a model response (OpenAI-style choice delta —
/// deserialize-only).
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    // Index of the choice this delta belongs to.
    pub index: u32,
    // The partial message carried by this chunk.
    pub delta: LanguageModelResponseMessage,
    // Set on the final chunk of a choice (e.g. why generation stopped).
    pub finish_reason: Option<String>,
}
243
/// Lifecycle state of a message in a context.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum MessageStatus {
    /// Still being generated.
    Pending,
    /// Generation finished successfully.
    Done,
    /// Generation failed; carries the error message.
    Error(SharedString),
}
250
251impl MessageStatus {
252 pub fn from_proto(status: proto::ContextMessageStatus) -> MessageStatus {
253 match status.variant {
254 Some(proto::context_message_status::Variant::Pending(_)) => MessageStatus::Pending,
255 Some(proto::context_message_status::Variant::Done(_)) => MessageStatus::Done,
256 Some(proto::context_message_status::Variant::Error(error)) => {
257 MessageStatus::Error(error.message.into())
258 }
259 None => MessageStatus::Pending,
260 }
261 }
262
263 pub fn to_proto(&self) -> proto::ContextMessageStatus {
264 match self {
265 MessageStatus::Pending => proto::ContextMessageStatus {
266 variant: Some(proto::context_message_status::Variant::Pending(
267 proto::context_message_status::Pending {},
268 )),
269 },
270 MessageStatus::Done => proto::ContextMessageStatus {
271 variant: Some(proto::context_message_status::Variant::Done(
272 proto::context_message_status::Done {},
273 )),
274 },
275 MessageStatus::Error(message) => proto::ContextMessageStatus {
276 variant: Some(proto::context_message_status::Variant::Error(
277 proto::context_message_status::Error {
278 message: message.to_string(),
279 },
280 )),
281 },
282 }
283 }
284}
285
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled (mirrors `AssistantSettings::enabled`;
    /// kept in sync by the settings observer installed in `init`).
    enabled: bool,
}

// Lets `Assistant` be stored as a GPUI global (one instance per app).
impl Global for Assistant {}
294
295impl Assistant {
296 const NAMESPACE: &'static str = "assistant";
297
298 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
299 if self.enabled == enabled {
300 return;
301 }
302
303 self.enabled = enabled;
304
305 if !enabled {
306 CommandPaletteFilter::update_global(cx, |filter, _cx| {
307 filter.hide_namespace(Self::NAMESPACE);
308 });
309
310 return;
311 }
312
313 CommandPaletteFilter::update_global(cx, |filter, _cx| {
314 filter.show_namespace(Self::NAMESPACE);
315 });
316 }
317}
318
/// Wires the assistant into the application: registers globals and settings,
/// kicks off semantic-index construction, registers slash commands and
/// panels, and keeps the enabled state in sync with settings.
pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index in the background; its cloud embedding
    // provider shares the app's client. Errors are dropped by `detach`.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                paths::embeddings_dir().join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();

    // Initialize the assistant's subsystems. Slash commands are registered
    // after the slash-command infrastructure itself.
    context_store::init(&client);
    prompt_library::init(cx);
    completion_provider::init(client.clone(), cx);
    assistant_slash_command::init(cx);
    register_slash_commands(cx);
    assistant_panel::init(cx);
    inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
    IndexedDocsRegistry::init_global(cx);

    // Start with assistant actions hidden from the command palette, then
    // apply the current settings value...
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // ...and re-apply it whenever settings change.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);
            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
364
365fn register_slash_commands(cx: &mut AppContext) {
366 let slash_command_registry = SlashCommandRegistry::global(cx);
367 slash_command_registry.register_command(file_command::FileSlashCommand, true);
368 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
369 slash_command_registry.register_command(symbols_command::OutlineSlashCommand, true);
370 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
371 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
372 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
373 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
374 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
375 slash_command_registry.register_command(term_command::TermSlashCommand, true);
376 slash_command_registry.register_command(now_command::NowSlashCommand, true);
377 slash_command_registry.register_command(diagnostics_command::DiagnosticsSlashCommand, true);
378 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
379 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
380}
381
/// Formats a token count compactly for display.
///
/// Counts below 1000 print as-is. Between 1k and 10k the value is rounded to
/// the nearest tenth of a thousand (e.g. `1.5k`), dropping a `.0` suffix and
/// carrying `9.95k+` up to `10k`. From 10k upward it rounds to the nearest
/// whole thousand.
pub fn humanize_token_count(count: usize) -> String {
    if count < 1000 {
        return count.to_string();
    }

    if count < 10_000 {
        let thousands = count / 1000;
        // Round the sub-thousand remainder to the nearest hundred (0..=10).
        let tenths = (count % 1000 + 50) / 100;
        return match tenths {
            0 => format!("{}k", thousands),
            10 => format!("{}k", thousands + 1),
            _ => format!("{}.{}k", thousands, tenths),
        };
    }

    format!("{}k", (count + 500) / 1000)
}
399
/// Test-only hook (runs before the test harness via `ctor`) that enables
/// `env_logger` output when `RUST_LOG` is set in the environment.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}