1use serde::{Deserialize, Serialize};
2use std::fmt::{self, Display};
3
4#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq, Hash)]
5#[serde(rename_all = "lowercase")]
6pub enum Role {
7 User,
8 Assistant,
9 System,
10}
11
12impl Role {
13 pub fn from_proto(role: i32) -> Role {
14 match proto::LanguageModelRole::from_i32(role) {
15 Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
16 Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
17 Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
18 None => Role::User,
19 }
20 }
21
22 pub const fn to_proto(self) -> proto::LanguageModelRole {
23 match self {
24 Role::User => proto::LanguageModelRole::LanguageModelUser,
25 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
26 Role::System => proto::LanguageModelRole::LanguageModelSystem,
27 }
28 }
29
30 pub const fn cycle(self) -> Role {
31 match self {
32 Role::User => Role::Assistant,
33 Role::Assistant => Role::System,
34 Role::System => Role::User,
35 }
36 }
37}
38
39impl Display for Role {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Role::User => write!(f, "user"),
43 Role::Assistant => write!(f, "assistant"),
44 Role::System => write!(f, "system"),
45 }
46 }
47}