1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context_store;
5mod inline_assistant;
6mod model_selector;
7mod prompt_library;
8mod prompts;
9mod search;
10mod slash_command;
11mod streaming_diff;
12mod terminal_inline_assistant;
13
14pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20pub(crate) use context_store::*;
21use fs::Fs;
22use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
23use indexed_docs::IndexedDocsRegistry;
24pub(crate) use inline_assistant::*;
25pub(crate) use model_selector::*;
26use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
27use serde::{Deserialize, Serialize};
28use settings::{Settings, SettingsStore};
29use slash_command::{
30 active_command, default_command, diagnostics_command, docs_command, fetch_command,
31 file_command, now_command, project_command, prompt_command, search_command, tabs_command,
32 term_command,
33};
34use std::{
35 fmt::{self, Display},
36 sync::Arc,
37};
38pub(crate) use streaming_diff::*;
39
40actions!(
41 assistant,
42 [
43 Assist,
44 Split,
45 CycleMessageRole,
46 QuoteSelection,
47 InsertIntoEditor,
48 ToggleFocus,
49 ResetKey,
50 InlineAssist,
51 InsertActivePrompt,
52 ToggleHistory,
53 ApplyEdit,
54 ConfirmCommand,
55 ToggleModelSelector
56 ]
57);
58
59#[derive(
60 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
61)]
62struct MessageId(usize);
63
64#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
65#[serde(rename_all = "lowercase")]
66pub enum Role {
67 User,
68 Assistant,
69 System,
70}
71
72impl Role {
73 pub fn cycle(&mut self) {
74 *self = match self {
75 Role::User => Role::Assistant,
76 Role::Assistant => Role::System,
77 Role::System => Role::User,
78 }
79 }
80}
81
82impl Display for Role {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
84 match self {
85 Role::User => write!(f, "user"),
86 Role::Assistant => write!(f, "assistant"),
87 Role::System => write!(f, "system"),
88 }
89 }
90}
91
92#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
93pub enum LanguageModel {
94 Cloud(CloudModel),
95 OpenAi(OpenAiModel),
96 Anthropic(AnthropicModel),
97 Ollama(OllamaModel),
98}
99
100impl Default for LanguageModel {
101 fn default() -> Self {
102 LanguageModel::Cloud(CloudModel::default())
103 }
104}
105
106impl LanguageModel {
107 pub fn telemetry_id(&self) -> String {
108 match self {
109 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
110 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
111 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
112 LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
113 }
114 }
115
116 pub fn display_name(&self) -> String {
117 match self {
118 LanguageModel::OpenAi(model) => model.display_name().into(),
119 LanguageModel::Anthropic(model) => model.display_name().into(),
120 LanguageModel::Cloud(model) => model.display_name().into(),
121 LanguageModel::Ollama(model) => model.display_name().into(),
122 }
123 }
124
125 pub fn max_token_count(&self) -> usize {
126 match self {
127 LanguageModel::OpenAi(model) => model.max_token_count(),
128 LanguageModel::Anthropic(model) => model.max_token_count(),
129 LanguageModel::Cloud(model) => model.max_token_count(),
130 LanguageModel::Ollama(model) => model.max_token_count(),
131 }
132 }
133
134 pub fn id(&self) -> &str {
135 match self {
136 LanguageModel::OpenAi(model) => model.id(),
137 LanguageModel::Anthropic(model) => model.id(),
138 LanguageModel::Cloud(model) => model.id(),
139 LanguageModel::Ollama(model) => model.id(),
140 }
141 }
142}
143
144#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
145pub struct LanguageModelRequestMessage {
146 pub role: Role,
147 pub content: String,
148}
149
150impl LanguageModelRequestMessage {
151 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
152 proto::LanguageModelRequestMessage {
153 role: match self.role {
154 Role::User => proto::LanguageModelRole::LanguageModelUser,
155 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
156 Role::System => proto::LanguageModelRole::LanguageModelSystem,
157 } as i32,
158 content: self.content.clone(),
159 tool_calls: Vec::new(),
160 tool_call_id: None,
161 }
162 }
163}
164
165#[derive(Debug, Default, Serialize)]
166pub struct LanguageModelRequest {
167 pub model: LanguageModel,
168 pub messages: Vec<LanguageModelRequestMessage>,
169 pub stop: Vec<String>,
170 pub temperature: f32,
171}
172
173impl LanguageModelRequest {
174 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
175 proto::CompleteWithLanguageModel {
176 model: self.model.id().to_string(),
177 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
178 stop: self.stop.clone(),
179 temperature: self.temperature,
180 tool_choice: None,
181 tools: Vec::new(),
182 }
183 }
184
185 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
186 pub fn preprocess(&mut self) {
187 match &self.model {
188 LanguageModel::OpenAi(_) => {}
189 LanguageModel::Anthropic(_) => {}
190 LanguageModel::Ollama(_) => {}
191 LanguageModel::Cloud(model) => match model {
192 CloudModel::Claude3Opus
193 | CloudModel::Claude3Sonnet
194 | CloudModel::Claude3Haiku
195 | CloudModel::Claude3_5Sonnet => {
196 preprocess_anthropic_request(self);
197 }
198 _ => {}
199 },
200 }
201 }
202}
203
204#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
205pub struct LanguageModelResponseMessage {
206 pub role: Option<Role>,
207 pub content: Option<String>,
208}
209
210#[derive(Deserialize, Debug)]
211pub struct LanguageModelUsage {
212 pub prompt_tokens: u32,
213 pub completion_tokens: u32,
214 pub total_tokens: u32,
215}
216
217#[derive(Deserialize, Debug)]
218pub struct LanguageModelChoiceDelta {
219 pub index: u32,
220 pub delta: LanguageModelResponseMessage,
221 pub finish_reason: Option<String>,
222}
223
224#[derive(Clone, Debug, Serialize, Deserialize)]
225struct MessageMetadata {
226 role: Role,
227 status: MessageStatus,
228}
229
230#[derive(Clone, Debug, Serialize, Deserialize)]
231enum MessageStatus {
232 Pending,
233 Done,
234 Error(SharedString),
235}
236
237/// The state pertaining to the Assistant.
238#[derive(Default)]
239struct Assistant {
240 /// Whether the Assistant is enabled.
241 enabled: bool,
242}
243
244impl Global for Assistant {}
245
246impl Assistant {
247 const NAMESPACE: &'static str = "assistant";
248
249 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
250 if self.enabled == enabled {
251 return;
252 }
253
254 self.enabled = enabled;
255
256 if !enabled {
257 CommandPaletteFilter::update_global(cx, |filter, _cx| {
258 filter.hide_namespace(Self::NAMESPACE);
259 });
260
261 return;
262 }
263
264 CommandPaletteFilter::update_global(cx, |filter, _cx| {
265 filter.show_namespace(Self::NAMESPACE);
266 });
267 }
268}
269
270pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
271 cx.set_global(Assistant::default());
272 AssistantSettings::register(cx);
273
274 cx.spawn(|mut cx| {
275 let client = client.clone();
276 async move {
277 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
278 let semantic_index = SemanticIndex::new(
279 paths::embeddings_dir().join("semantic-index-db.0.mdb"),
280 Arc::new(embedding_provider),
281 &mut cx,
282 )
283 .await?;
284 cx.update(|cx| cx.set_global(semantic_index))
285 }
286 })
287 .detach();
288
289 prompt_library::init(cx);
290 completion_provider::init(client.clone(), cx);
291 assistant_slash_command::init(cx);
292 register_slash_commands(cx);
293 assistant_panel::init(cx);
294 inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
295 terminal_inline_assistant::init(fs.clone(), client.telemetry().clone(), cx);
296 IndexedDocsRegistry::init_global(cx);
297
298 CommandPaletteFilter::update_global(cx, |filter, _cx| {
299 filter.hide_namespace(Assistant::NAMESPACE);
300 });
301 Assistant::update_global(cx, |assistant, cx| {
302 let settings = AssistantSettings::get_global(cx);
303
304 assistant.set_enabled(settings.enabled, cx);
305 });
306 cx.observe_global::<SettingsStore>(|cx| {
307 Assistant::update_global(cx, |assistant, cx| {
308 let settings = AssistantSettings::get_global(cx);
309 assistant.set_enabled(settings.enabled, cx);
310 });
311 })
312 .detach();
313}
314
315fn register_slash_commands(cx: &mut AppContext) {
316 let slash_command_registry = SlashCommandRegistry::global(cx);
317 slash_command_registry.register_command(file_command::FileSlashCommand, true);
318 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
319 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
320 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
321 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
322 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
323 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
324 slash_command_registry.register_command(term_command::TermSlashCommand, true);
325 slash_command_registry.register_command(now_command::NowSlashCommand, true);
326 slash_command_registry.register_command(diagnostics_command::DiagnosticsCommand, true);
327 slash_command_registry.register_command(docs_command::DocsSlashCommand, true);
328 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
329}
330
/// Formats a token count compactly for display.
///
/// * `0..=999` is printed verbatim, e.g. `"512"`.
/// * `1000..=9999` is rounded half-up to one decimal place of thousands, with a
///   zero fraction omitted: `1050 -> "1.1k"`, `1950 -> "2k"`, `9950 -> "10k"`.
/// * Anything larger is rounded half-up to whole thousands: `10_500 -> "11k"`.
///
/// Never panics: rounding in the last range uses the quotient and remainder
/// rather than `count + 500`, which would overflow near `usize::MAX`.
pub fn humanize_token_count(count: usize) -> String {
    match count {
        0..=999 => count.to_string(),
        1000..=9999 => {
            // Round to the nearest tenth of a thousand in one step; a carry out of
            // the fractional digit (e.g. 1950 -> 20 tenths) lands in `whole`.
            let tenths = (count + 50) / 100;
            let whole = tenths / 10;
            let fraction = tenths % 10;
            if fraction == 0 {
                format!("{}k", whole)
            } else {
                format!("{}.{}k", whole, fraction)
            }
        }
        _ => {
            // Half-up rounding without an intermediate sum that could overflow.
            let round_up = usize::from(count % 1000 >= 500);
            format!("{}k", count / 1000 + round_up)
        }
    }
}
348
/// Test-only constructor: initializes `env_logger` when `RUST_LOG` is set to a
/// valid Unicode value, and leaves logging untouched otherwise.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_set {
        env_logger::init();
    }
}