1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package genai
16
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/gorilla/websocket"
)
26
27// Preview. Live serves as the entry point for establishing real-time WebSocket
28// connections to the API. It manages the initial handshake and setup process.
29//
30// It is initiated when creating a client via [NewClient]. You don't need to
31// create a new Live object directly. Access it through the `Live` field of a
32// `Client` instance.
33//
34// client, _ := genai.NewClient(ctx, &genai.ClientConfig{})
35// session, _ := client.Live.Connect(ctx, model, &genai.LiveConnectConfig{}).
36type Live struct {
37 apiClient *apiClient
38}
39
40// Preview. Session represents an active, real-time WebSocket connection to the
41// Generative AI API. It provides methods for sending client messages and
42// receiving server messages over the established connection.
43type Session struct {
44 conn *websocket.Conn
45 apiClient *apiClient
46}
47
48// Preview. Connect establishes a WebSocket connection to the specified
49// model with the given configuration. It sends the initial
50// setup message and returns a [Session] object representing the connection.
51func (r *Live) Connect(context context.Context, model string, config *LiveConnectConfig) (*Session, error) {
52 httpOptions := r.apiClient.clientConfig.HTTPOptions
53 if httpOptions.APIVersion == "" {
54 return nil, fmt.Errorf("live module requires APIVersion to be set. You can set APIVersion to v1beta1 for BackendVertexAI or v1apha for BackendGeminiAPI")
55 }
56 baseURL, err := url.Parse(httpOptions.BaseURL)
57 if err != nil {
58 return nil, fmt.Errorf("failed to parse base URL: %w", err)
59 }
60 scheme := baseURL.Scheme
61 // Avoid overwrite schema if websocket scheme is already specified.
62 if scheme != "wss" && scheme != "ws" {
63 scheme = "wss"
64 }
65
66 var u url.URL
67 // TODO(b/406076143): Support function level httpOptions.
68 var header http.Header = mergeHeaders(&httpOptions, nil)
69 if r.apiClient.clientConfig.Backend == BackendVertexAI {
70 token, err := r.apiClient.clientConfig.Credentials.Token(context)
71 if err != nil {
72 return nil, fmt.Errorf("failed to get token: %w", err)
73 }
74 header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Value))
75 u = url.URL{
76 Scheme: scheme,
77 Host: baseURL.Host,
78 Path: fmt.Sprintf("%s/ws/google.cloud.aiplatform.%s.LlmBidiService/BidiGenerateContent", baseURL.Path, httpOptions.APIVersion),
79 }
80 } else {
81 u = url.URL{
82 Scheme: scheme,
83 Host: baseURL.Host,
84 Path: fmt.Sprintf("%s/ws/google.ai.generativelanguage.%s.GenerativeService.BidiGenerateContent", baseURL.Path, httpOptions.APIVersion),
85 RawQuery: fmt.Sprintf("key=%s", r.apiClient.clientConfig.APIKey),
86 }
87 }
88
89 conn, _, err := websocket.DefaultDialer.Dial(u.String(), header)
90 if err != nil {
91 return nil, fmt.Errorf("Connect to %s failed: %w", u.String(), err)
92 }
93 s := &Session{
94 conn: conn,
95 apiClient: r.apiClient,
96 }
97 modelFullName, err := tModelFullName(r.apiClient, model)
98 if err != nil {
99 return nil, err
100 }
101 kwargs := map[string]any{"model": modelFullName, "config": config}
102 parameterMap := make(map[string]any)
103 err = deepMarshal(kwargs, ¶meterMap)
104 if err != nil {
105 return nil, err
106 }
107
108 var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
109 if r.apiClient.clientConfig.Backend == BackendVertexAI {
110 toConverter = liveConnectParametersToVertex
111 } else {
112 toConverter = liveConnectParametersToMldev
113 }
114 body, err := toConverter(r.apiClient, parameterMap, nil)
115 if err != nil {
116 return nil, err
117 }
118 delete(body, "config")
119
120 clientBytes, err := json.Marshal(body)
121 if err != nil {
122 return nil, fmt.Errorf("marshal LiveClientSetup failed: %w", err)
123 }
124 err = s.conn.WriteMessage(websocket.TextMessage, clientBytes)
125 if err != nil {
126 return nil, fmt.Errorf("failed to write LiveClientSetup: %w", err)
127 }
128 return s, nil
129}
130
131// Preview. LiveClientContentInput is the input for [SendClientContent].
132type LiveClientContentInput = LiveSendClientContentParameters
133
134// Preview. SendClientContent transmits non-realtime, turn-based content to the model
135// over the established WebSocket connection.
136//
137// There are two primary ways to send messages in a live session:
138// [SendClientContent] and [SendRealtimeInput].
139//
140// Messages sent via [SendClientContent] are added to the model's context strictly
141// **in the order they are sent**. A conversation using [SendClientContent] is
142// similar to using the [Chat.SendMessageStream] method, but the conversation
143// history state is managed by the API server.
144//
145// Due to this ordering guarantee, the model might not respond as quickly to
146// [SendClientContent] messages compared to SendRealtimeInput messages. This latency
147// difference is most noticeable when sending content that requires significant
148// preprocessing, such as images.
149//
150// [SendClientContent] accepts a LiveClientContentInput which contains a list of
151// [*Content] objects, offering more flexibility than the [*Blob] used by
152// SendRealtimeInput.
153//
154// Key use cases for [SendClientContent] over SendRealtimeInput include:
155// - Pre-populating the conversation context (including sending content types
156// not supported by realtime messages) before starting a realtime interaction.
157// - Conducting a non-realtime conversation, similar to client.Chats.SendMessage,
158// using the live API infrastructure.
159//
160// Caution: Interleaving [SendClientContent] and SendRealtimeInput within the
161// same conversation is not recommended and may lead to unexpected behavior.
162//
163// The input parameter of type [LiveClientContentInput] contains:
164// - Turns: A slice of [*Content] objects representing the message(s) to send.
165// - TurnComplete: If true (the default), the model will reply immediately.
166// If false, the model waits for subsequent SendClientContent calls until
167// one is sent with TurnComplete set to true.
168func (s *Session) SendClientContent(input LiveClientContentInput) error {
169 return s.send(input.toLiveClientMessage())
170}
171
172// Preview. LiveRealtimeInput is the input for [SendRealtimeInput].
173type LiveRealtimeInput = LiveSendRealtimeInputParameters
174
175// Preview. SendRealtimeInput transmits realtime audio chunks and video frames (images)
176// to the model over the established WebSocket connection.
177//
178// Use SendRealtimeInput for streaming audio and video data. The API automatically
179// responds to audio based on voice activity detection (VAD).
180//
181// SendRealtimeInput is optimized for responsiveness, potentially at the expense
182// of deterministic ordering. Audio and video tokens are added to the model's
183// context as they become available, allowing for faster interaction.
184//
185// It accepts a [LiveRealtimeInput] parameter containing the media data.
186// Only one argument (e.g., Media, Audio, Video, Text) should be provided per call.
187func (s *Session) SendRealtimeInput(input LiveRealtimeInput) error {
188 parameterMap := make(map[string]any)
189 err := deepMarshal(input, ¶meterMap)
190 if err != nil {
191 return err
192 }
193
194 var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
195 if s.apiClient.clientConfig.Backend == BackendVertexAI {
196 toConverter = liveSendRealtimeInputParametersToVertex
197 } else {
198 toConverter = liveSendRealtimeInputParametersToMldev
199 }
200 body, err := toConverter(s.apiClient, parameterMap, nil)
201 if err != nil {
202 return err
203 }
204
205 data, err := json.Marshal(map[string]any{"realtimeInput": body})
206 if err != nil {
207 return fmt.Errorf("marshal client message error: %w", err)
208 }
209 return s.conn.WriteMessage(websocket.TextMessage, []byte(data))
210}
211
212// Preview. LiveToolResponseInput is the input for [SendToolResponse].
213type LiveToolResponseInput = LiveSendToolResponseParameters
214
215// Preview. SendToolResponse transmits a [LiveClientToolResponse] over the established WebSocket connection.
216//
217// Use SendToolResponse to reply to [LiveServerToolCall] messages received from the server.
218//
219// To define the available tools for the session, set the [LiveConnectConfig.Tools]
220// field when establishing the connection via [Live.Connect].
221func (s *Session) SendToolResponse(input LiveToolResponseInput) error {
222 return s.send(input.toLiveClientMessage())
223}
224
225// Send transmits a LiveClientMessage over the established connection.
226// It returns an error if sending the message fails.
227func (s *Session) send(input *LiveClientMessage) error {
228 if input.Setup != nil {
229 return fmt.Errorf("message SetUp is not supported in Send(). Use Connect() instead")
230 }
231
232 parameterMap := make(map[string]any)
233 err := deepMarshal(input, ¶meterMap)
234 if err != nil {
235 return err
236 }
237
238 var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
239 if s.apiClient.clientConfig.Backend == BackendVertexAI {
240 toConverter = liveClientMessageToVertex
241 } else {
242 toConverter = liveClientMessageToMldev
243 }
244 body, err := toConverter(s.apiClient, parameterMap, nil)
245 if err != nil {
246 return err
247 }
248
249 data, err := json.Marshal(body)
250 if err != nil {
251 return fmt.Errorf("marshal client message error: %w", err)
252 }
253 return s.conn.WriteMessage(websocket.TextMessage, []byte(data))
254}
255
256// Preview. Receive reads a LiveServerMessage from the connection.
257//
258// This method blocks until a message is received from the server.
259// The returned message represents a part of or a complete model turn.
260// If the received message is a [LiveServerToolCall], the user must call
261// [SendToolResponse] to provide the function execution result and continue the turn.
262func (s *Session) Receive() (*LiveServerMessage, error) {
263 messageType, msgBytes, err := s.conn.ReadMessage()
264 if err != nil {
265 return nil, err
266 }
267 responseMap := make(map[string]any)
268 err = json.Unmarshal(msgBytes, &responseMap)
269 if err != nil {
270 return nil, fmt.Errorf("invalid message format. Error %w. messageType: %d, message: %s", err, messageType, msgBytes)
271 }
272 if responseMap["error"] != nil {
273 return nil, fmt.Errorf("received error in response: %v", string(msgBytes))
274 }
275
276 var fromConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
277 if s.apiClient.clientConfig.Backend == BackendVertexAI {
278 fromConverter = liveServerMessageFromVertex
279 } else {
280 fromConverter = liveServerMessageFromMldev
281 }
282 responseMap, err = fromConverter(s.apiClient, responseMap, nil)
283 if err != nil {
284 return nil, err
285 }
286
287 var message = new(LiveServerMessage)
288 err = mapToStruct(responseMap, message)
289 if err != nil {
290 return nil, err
291 }
292 return message, err
293}
294
295// Preview. Close terminates the connection.
296func (s *Session) Close() error {
297 if s != nil && s.conn != nil {
298 return s.conn.Close()
299 }
300 return nil
301}