1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod model_selector;
6mod prompts;
7mod saved_conversation;
8mod search;
9mod slash_command;
10mod streaming_diff;
11
12pub use assistant_panel::AssistantPanel;
13
14use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
15use client::{proto, Client};
16use command_palette_hooks::CommandPaletteFilter;
17pub(crate) use completion_provider::*;
18use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
19pub(crate) use model_selector::*;
20pub(crate) use saved_conversation::*;
21use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
22use serde::{Deserialize, Serialize};
23use settings::{Settings, SettingsStore};
24use std::{
25 fmt::{self, Display},
26 sync::Arc,
27};
28use util::paths::EMBEDDINGS_DIR;
29
// Actions registered under the `assistant` namespace. This namespace is
// shown/hidden in the command palette by `Assistant::set_enabled` below.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
47
/// Newtype identifier for a conversation message; serialized as its inner
/// `usize`. Ordering/equality derive from the wrapped value.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
52
/// Author of a conversation message. Serialized in lowercase
/// (`"user"` / `"assistant"` / `"system"`), matching the wire format
/// expected by LLM chat APIs.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
60
61impl Role {
62 pub fn cycle(&mut self) {
63 *self = match self {
64 Role::User => Role::Assistant,
65 Role::Assistant => Role::System,
66 Role::System => Role::User,
67 }
68 }
69}
70
71impl Display for Role {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Role::User => write!(f, "user"),
75 Role::Assistant => write!(f, "assistant"),
76 Role::System => write!(f, "system"),
77 }
78 }
79}
80
/// A concrete model selection, tagged by the provider that serves it.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    ZedDotDev(ZedDotDevModel),
    OpenAi(OpenAiModel),
    Anthropic(AnthropicModel),
}
87
88impl Default for LanguageModel {
89 fn default() -> Self {
90 LanguageModel::ZedDotDev(ZedDotDevModel::default())
91 }
92}
93
94impl LanguageModel {
95 pub fn telemetry_id(&self) -> String {
96 match self {
97 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
98 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
99 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
100 }
101 }
102
103 pub fn display_name(&self) -> String {
104 match self {
105 LanguageModel::OpenAi(model) => model.display_name().into(),
106 LanguageModel::Anthropic(model) => model.display_name().into(),
107 LanguageModel::ZedDotDev(model) => model.display_name().into(),
108 }
109 }
110
111 pub fn max_token_count(&self) -> usize {
112 match self {
113 LanguageModel::OpenAi(model) => model.max_token_count(),
114 LanguageModel::Anthropic(model) => model.max_token_count(),
115 LanguageModel::ZedDotDev(model) => model.max_token_count(),
116 }
117 }
118
119 pub fn id(&self) -> &str {
120 match self {
121 LanguageModel::OpenAi(model) => model.id(),
122 LanguageModel::Anthropic(model) => model.id(),
123 LanguageModel::ZedDotDev(model) => model.id(),
124 }
125 }
126}
127
/// One chat message in an outgoing completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    pub role: Role,
    pub content: String,
}
133
134impl LanguageModelRequestMessage {
135 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
136 proto::LanguageModelRequestMessage {
137 role: match self.role {
138 Role::User => proto::LanguageModelRole::LanguageModelUser,
139 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
140 Role::System => proto::LanguageModelRole::LanguageModelSystem,
141 } as i32,
142 content: self.content.clone(),
143 tool_calls: Vec::new(),
144 tool_call_id: None,
145 }
146 }
147}
148
/// A full completion request: target model, conversation so far,
/// stop sequences, and sampling temperature.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    pub model: LanguageModel,
    pub messages: Vec<LanguageModelRequestMessage>,
    pub stop: Vec<String>,
    pub temperature: f32,
}
156
157impl LanguageModelRequest {
158 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
159 proto::CompleteWithLanguageModel {
160 model: self.model.id().to_string(),
161 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
162 stop: self.stop.clone(),
163 temperature: self.temperature,
164 tool_choice: None,
165 tools: Vec::new(),
166 }
167 }
168}
169
/// A (possibly partial) message from the model; both fields are optional
/// because streamed deltas may carry only one of them.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
175
/// Token accounting reported by the provider for a completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
182
/// One streamed choice delta: its index, the partial message, and the
/// finish reason (present only on the final chunk).
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    pub index: u32,
    pub delta: LanguageModelResponseMessage,
    pub finish_reason: Option<String>,
}
189
/// Per-message bookkeeping: who authored it and its completion status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    status: MessageStatus,
}
195
/// Lifecycle of a message's generation: in flight, finished, or failed
/// with a user-visible error string.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    Pending,
    Done,
    Error(SharedString),
}
202
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Stored as a gpui global so any part of the app can read/update it.
impl Global for Assistant {}
211
212impl Assistant {
213 const NAMESPACE: &'static str = "assistant";
214
215 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
216 if self.enabled == enabled {
217 return;
218 }
219
220 self.enabled = enabled;
221
222 if !enabled {
223 CommandPaletteFilter::update_global(cx, |filter, _cx| {
224 filter.hide_namespace(Self::NAMESPACE);
225 });
226
227 return;
228 }
229
230 CommandPaletteFilter::update_global(cx, |filter, _cx| {
231 filter.show_namespace(Self::NAMESPACE);
232 });
233 }
234}
235
/// Initializes the assistant crate: registers settings, starts building the
/// semantic index in the background, wires up the completion provider,
/// slash commands, and panel, and keeps command-palette visibility in sync
/// with the `assistant.enabled` setting.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Build the semantic index off the main flow; `detach()` below means
    // any error from the async block is dropped silently.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();
    completion_provider::init(client, cx);
    assistant_slash_command::init(cx);
    assistant_panel::init(cx);

    // Hide the namespace first: `set_enabled` only reacts to state
    // *changes*, so starting hidden makes the initial sync below correct
    // whether the setting is on or off.
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply enablement whenever the settings store changes.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);

            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
275
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only hook up env_logger when the caller opted in via RUST_LOG.
    let rust_log_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_set {
        env_logger::init();
    }
}