1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod streaming_diff;
8
9mod embedded_scope;
10
11pub use assistant_panel::AssistantPanel;
12use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
13use chrono::{DateTime, Local};
14use client::{proto, Client};
15use command_palette_hooks::CommandPaletteFilter;
16pub(crate) use completion_provider::*;
17use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
18pub(crate) use saved_conversation::*;
19use serde::{Deserialize, Serialize};
20use settings::{Settings, SettingsStore};
21use std::{
22 fmt::{self, Display},
23 sync::Arc,
24};
25
// Actions provided by the assistant, all declared in the `assistant` namespace.
// That namespace string is the same one `Assistant::NAMESPACE` shows or hides in
// the command palette, depending on the `enabled` assistant setting.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
    ]
);
40
/// Newtype identifier wrapping a `usize` (presumably one per conversation
/// message — allocation happens outside this file).
///
/// The derived `Ord`/`Hash`/serde impls all operate on the wrapped value.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
45
/// The role attached to a message sent to or received from a language model
/// (see `LanguageModelRequestMessage::role`).
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`), the same
/// spelling the `Display` impl for this type produces.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
53
54impl Role {
55 pub fn cycle(&mut self) {
56 *self = match self {
57 Role::User => Role::Assistant,
58 Role::Assistant => Role::System,
59 Role::System => Role::User,
60 }
61 }
62}
63
64impl Display for Role {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Role::User => write!(f, "user"),
68 Role::Assistant => write!(f, "assistant"),
69 Role::System => write!(f, "system"),
70 }
71 }
72}
73
/// A language model, tagged by the provider it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model from the zed.dev provider (reported with a `zed.dev/` prefix).
    ZedDotDev(ZedDotDevModel),
    /// An OpenAI model (reported with an `openai/` prefix).
    OpenAi(OpenAiModel),
}
79
80impl Default for LanguageModel {
81 fn default() -> Self {
82 LanguageModel::ZedDotDev(ZedDotDevModel::default())
83 }
84}
85
86impl LanguageModel {
87 pub fn telemetry_id(&self) -> String {
88 match self {
89 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
90 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
91 }
92 }
93
94 pub fn display_name(&self) -> String {
95 match self {
96 LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
97 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
98 }
99 }
100
101 pub fn max_token_count(&self) -> usize {
102 match self {
103 LanguageModel::OpenAi(model) => model.max_token_count(),
104 LanguageModel::ZedDotDev(model) => model.max_token_count(),
105 }
106 }
107
108 pub fn id(&self) -> &str {
109 match self {
110 LanguageModel::OpenAi(model) => model.id(),
111 LanguageModel::ZedDotDev(model) => model.id(),
112 }
113 }
114}
115
/// A single message included in a request to a language model.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// The role the message is attributed to.
    pub role: Role,
    /// The message text.
    pub content: String,
}
121
122impl LanguageModelRequestMessage {
123 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
124 proto::LanguageModelRequestMessage {
125 role: match self.role {
126 Role::User => proto::LanguageModelRole::LanguageModelUser,
127 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
128 Role::System => proto::LanguageModelRole::LanguageModelSystem,
129 } as i32,
130 content: self.content.clone(),
131 }
132 }
133}
134
/// A request to a language model.
///
/// The derived `Default` uses `LanguageModel::default()`, empty `messages` and
/// `stop`, and a `temperature` of `0.0`.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// The model to send the request to.
    pub model: LanguageModel,
    /// The messages that make up the request, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Strings forwarded as `stop` (presumably stop sequences for generation).
    pub stop: Vec<String>,
    /// Sampling temperature forwarded to the provider.
    pub temperature: f32,
}
142
143impl LanguageModelRequest {
144 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
145 proto::CompleteWithLanguageModel {
146 model: self.model.id().to_string(),
147 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
148 stop: self.stop.clone(),
149 temperature: self.temperature,
150 }
151 }
152}
153
/// A (possibly partial) message returned by a language model.
///
/// Both fields are optional. Presumably a streamed delta may carry only the
/// role or only a chunk of content; confirm against the deserializing code.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    /// The role of the message, when present.
    pub role: Option<Role>,
    /// Message text, when present.
    pub content: Option<String>,
}
159
/// Token counts deserialized from a response (presumably as reported by the
/// provider; nothing here computes or validates them).
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: u32,
    /// Tokens attributed to the completion.
    pub completion_tokens: u32,
    /// Total tokens. Presumably `prompt_tokens + completion_tokens`, but that
    /// is not enforced here.
    pub total_tokens: u32,
}
166
/// One deserialized increment for a single completion choice. The naming
/// suggests it comes from a streamed response; confirm at the call sites.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The partial message carried by this delta.
    pub delta: LanguageModelResponseMessage,
    /// Why generation finished, if it did (presumably `None` on intermediate deltas).
    pub finish_reason: Option<String>,
}
173
/// Bookkeeping kept alongside a message: its role, when it was sent (local
/// time), and its current status.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    role: Role,
    sent_at: DateTime<Local>,
    status: MessageStatus,
}
180
/// Status of a message.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Not finished yet (presumably still waiting on the model; confirm in assistant_panel).
    Pending,
    /// Finished.
    Done,
    /// Failed; carries a description of the error.
    Error(SharedString),
}
187
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled. `false` by default; `init` keeps it in
    /// sync with `AssistantSettings::enabled`.
    enabled: bool,
}

// Marking `Assistant` as a gpui `Global` is what lets `init` install it with
// `cx.set_global` and mutate it with `cx.update_global`.
impl Global for Assistant {}
196
197impl Assistant {
198 const NAMESPACE: &'static str = "assistant";
199
200 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
201 if self.enabled == enabled {
202 return;
203 }
204
205 self.enabled = enabled;
206
207 if !enabled {
208 CommandPaletteFilter::update_global(cx, |filter, _cx| {
209 filter.hide_namespace(Self::NAMESPACE);
210 });
211
212 return;
213 }
214
215 CommandPaletteFilter::update_global(cx, |filter, _cx| {
216 filter.show_namespace(Self::NAMESPACE);
217 });
218 }
219}
220
221pub fn init(client: Arc<Client>, cx: &mut AppContext) {
222 cx.set_global(Assistant::default());
223 AssistantSettings::register(cx);
224 completion_provider::init(client, cx);
225 assistant_panel::init(cx);
226
227 CommandPaletteFilter::update_global(cx, |filter, _cx| {
228 filter.hide_namespace(Assistant::NAMESPACE);
229 });
230 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
231 let settings = AssistantSettings::get_global(cx);
232
233 assistant.set_enabled(settings.enabled, cx);
234 });
235 cx.observe_global::<SettingsStore>(|cx| {
236 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
237 let settings = AssistantSettings::get_global(cx);
238
239 assistant.set_enabled(settings.enabled, cx);
240 });
241 })
242 .detach();
243}
244
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Test builds only: `#[ctor::ctor]` runs this at program startup. Install
    // env_logger only when RUST_LOG is set to a valid Unicode value, so test
    // output stays quiet by default.
    let rust_log_requested = std::env::var("RUST_LOG").is_ok();
    if rust_log_requested {
        env_logger::init();
    }
}