role.rs

 1use serde::{Deserialize, Serialize};
 2use std::fmt::{self, Display};
 3
 4#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq, Hash)]
 5#[serde(rename_all = "lowercase")]
 6pub enum Role {
 7    User,
 8    Assistant,
 9    System,
10}
11
12impl Role {
13    pub fn from_proto(role: i32) -> Role {
14        match proto::LanguageModelRole::from_i32(role) {
15            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
16            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
17            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
18            None => Role::User,
19        }
20    }
21
22    pub const fn to_proto(self) -> proto::LanguageModelRole {
23        match self {
24            Role::User => proto::LanguageModelRole::LanguageModelUser,
25            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
26            Role::System => proto::LanguageModelRole::LanguageModelSystem,
27        }
28    }
29
30    pub const fn cycle(self) -> Role {
31        match self {
32            Role::User => Role::Assistant,
33            Role::Assistant => Role::System,
34            Role::System => Role::User,
35        }
36    }
37}
38
39impl Display for Role {
40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            Role::User => write!(f, "user"),
43            Role::Assistant => write!(f, "assistant"),
44            Role::System => write!(f, "system"),
45        }
46    }
47}