live.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"context"
 19	"encoding/json"
 20	"fmt"
 21	"net/http"
 22	"net/url"
 23
 24	"github.com/gorilla/websocket"
 25)
 26
// Preview. Live serves as the entry point for establishing real-time WebSocket
// connections to the API. It manages the initial handshake and setup process.
//
// It is initiated when creating a client via [NewClient]. You don't need to
// create a new Live object directly. Access it through the `Live` field of a
// `Client` instance.
//
//	client, _ := genai.NewClient(ctx, &genai.ClientConfig{})
//	session, _ := client.Live.Connect(ctx, model, &genai.LiveConnectConfig{})
type Live struct {
	// apiClient supplies the client configuration (HTTP options, backend,
	// credentials and API key) that Connect reads when opening a session.
	apiClient *apiClient
}
 39
// Preview. Session represents an active, real-time WebSocket connection to the
// Generative AI API. It provides methods for sending client messages and
// receiving server messages over the established connection.
//
// A Session is returned by [Live.Connect]. It holds the WebSocket connection
// (conn) that messages are written to and read from, and the apiClient whose
// configured backend selects which converters are applied to outgoing and
// incoming messages. Call [Session.Close] to terminate the connection.
type Session struct {
	conn      *websocket.Conn
	apiClient *apiClient
}
 47
 48// Preview. Connect establishes a WebSocket connection to the specified
 49// model with the given configuration. It sends the initial
 50// setup message and returns a [Session] object representing the connection.
 51func (r *Live) Connect(context context.Context, model string, config *LiveConnectConfig) (*Session, error) {
 52	httpOptions := r.apiClient.clientConfig.HTTPOptions
 53	if httpOptions.APIVersion == "" {
 54		return nil, fmt.Errorf("live module requires APIVersion to be set. You can set APIVersion to v1beta1 for BackendVertexAI or v1apha for BackendGeminiAPI")
 55	}
 56	baseURL, err := url.Parse(httpOptions.BaseURL)
 57	if err != nil {
 58		return nil, fmt.Errorf("failed to parse base URL: %w", err)
 59	}
 60	scheme := baseURL.Scheme
 61	// Avoid overwrite schema if websocket scheme is already specified.
 62	if scheme != "wss" && scheme != "ws" {
 63		scheme = "wss"
 64	}
 65
 66	var u url.URL
 67	// TODO(b/406076143): Support function level httpOptions.
 68	var header http.Header = mergeHeaders(&httpOptions, nil)
 69	if r.apiClient.clientConfig.Backend == BackendVertexAI {
 70		token, err := r.apiClient.clientConfig.Credentials.Token(context)
 71		if err != nil {
 72			return nil, fmt.Errorf("failed to get token: %w", err)
 73		}
 74		header.Set("Authorization", fmt.Sprintf("Bearer %s", token.Value))
 75		u = url.URL{
 76			Scheme: scheme,
 77			Host:   baseURL.Host,
 78			Path:   fmt.Sprintf("%s/ws/google.cloud.aiplatform.%s.LlmBidiService/BidiGenerateContent", baseURL.Path, httpOptions.APIVersion),
 79		}
 80	} else {
 81		u = url.URL{
 82			Scheme:   scheme,
 83			Host:     baseURL.Host,
 84			Path:     fmt.Sprintf("%s/ws/google.ai.generativelanguage.%s.GenerativeService.BidiGenerateContent", baseURL.Path, httpOptions.APIVersion),
 85			RawQuery: fmt.Sprintf("key=%s", r.apiClient.clientConfig.APIKey),
 86		}
 87	}
 88
 89	conn, _, err := websocket.DefaultDialer.Dial(u.String(), header)
 90	if err != nil {
 91		return nil, fmt.Errorf("Connect to %s failed: %w", u.String(), err)
 92	}
 93	s := &Session{
 94		conn:      conn,
 95		apiClient: r.apiClient,
 96	}
 97	modelFullName, err := tModelFullName(r.apiClient, model)
 98	if err != nil {
 99		return nil, err
100	}
101	kwargs := map[string]any{"model": modelFullName, "config": config}
102	parameterMap := make(map[string]any)
103	err = deepMarshal(kwargs, &parameterMap)
104	if err != nil {
105		return nil, err
106	}
107
108	var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
109	if r.apiClient.clientConfig.Backend == BackendVertexAI {
110		toConverter = liveConnectParametersToVertex
111	} else {
112		toConverter = liveConnectParametersToMldev
113	}
114	body, err := toConverter(r.apiClient, parameterMap, nil)
115	if err != nil {
116		return nil, err
117	}
118	delete(body, "config")
119
120	clientBytes, err := json.Marshal(body)
121	if err != nil {
122		return nil, fmt.Errorf("marshal LiveClientSetup failed: %w", err)
123	}
124	err = s.conn.WriteMessage(websocket.TextMessage, clientBytes)
125	if err != nil {
126		return nil, fmt.Errorf("failed to write LiveClientSetup: %w", err)
127	}
128	return s, nil
129}
130
// Preview. LiveClientContentInput is the input for [Session.SendClientContent].
// It is an alias of [LiveSendClientContentParameters], so the two names are
// interchangeable.
type LiveClientContentInput = LiveSendClientContentParameters
133
134// Preview. SendClientContent transmits non-realtime, turn-based content to the model
135// over the established WebSocket connection.
136//
137// There are two primary ways to send messages in a live session:
138// [SendClientContent] and [SendRealtimeInput].
139//
140// Messages sent via [SendClientContent] are added to the model's context strictly
141// **in the order they are sent**. A conversation using [SendClientContent] is
142// similar to using the [Chat.SendMessageStream] method, but the conversation
143// history state is managed by the API server.
144//
145// Due to this ordering guarantee, the model might not respond as quickly to
146// [SendClientContent] messages compared to SendRealtimeInput messages. This latency
147// difference is most noticeable when sending content that requires significant
148// preprocessing, such as images.
149//
150// [SendClientContent] accepts a LiveClientContentInput which contains a list of
151// [*Content] objects, offering more flexibility than the [*Blob] used by
152// SendRealtimeInput.
153//
154// Key use cases for [SendClientContent] over SendRealtimeInput include:
155//   - Pre-populating the conversation context (including sending content types
156//     not supported by realtime messages) before starting a realtime interaction.
157//   - Conducting a non-realtime conversation, similar to client.Chats.SendMessage,
158//     using the live API infrastructure.
159//
160// Caution: Interleaving [SendClientContent] and SendRealtimeInput within the
161// same conversation is not recommended and may lead to unexpected behavior.
162//
163// The input parameter of type [LiveClientContentInput] contains:
164//   - Turns: A slice of [*Content] objects representing the message(s) to send.
165//   - TurnComplete: If true (the default), the model will reply immediately.
166//     If false, the model waits for subsequent SendClientContent calls until
167//     one is sent with TurnComplete set to true.
168func (s *Session) SendClientContent(input LiveClientContentInput) error {
169	return s.send(input.toLiveClientMessage())
170}
171
// Preview. LiveRealtimeInput is the input for [Session.SendRealtimeInput].
// It is an alias of [LiveSendRealtimeInputParameters], so the two names are
// interchangeable.
type LiveRealtimeInput = LiveSendRealtimeInputParameters
174
175// Preview. SendRealtimeInput transmits realtime audio chunks and video frames (images)
176// to the model over the established WebSocket connection.
177//
178// Use SendRealtimeInput for streaming audio and video data. The API automatically
179// responds to audio based on voice activity detection (VAD).
180//
181// SendRealtimeInput is optimized for responsiveness, potentially at the expense
182// of deterministic ordering. Audio and video tokens are added to the model's
183// context as they become available, allowing for faster interaction.
184//
185// It accepts a [LiveRealtimeInput] parameter containing the media data.
186// Only one argument (e.g., Media, Audio, Video, Text) should be provided per call.
187func (s *Session) SendRealtimeInput(input LiveRealtimeInput) error {
188	parameterMap := make(map[string]any)
189	err := deepMarshal(input, &parameterMap)
190	if err != nil {
191		return err
192	}
193
194	var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
195	if s.apiClient.clientConfig.Backend == BackendVertexAI {
196		toConverter = liveSendRealtimeInputParametersToVertex
197	} else {
198		toConverter = liveSendRealtimeInputParametersToMldev
199	}
200	body, err := toConverter(s.apiClient, parameterMap, nil)
201	if err != nil {
202		return err
203	}
204
205	data, err := json.Marshal(map[string]any{"realtimeInput": body})
206	if err != nil {
207		return fmt.Errorf("marshal client message error: %w", err)
208	}
209	return s.conn.WriteMessage(websocket.TextMessage, []byte(data))
210}
211
// Preview. LiveToolResponseInput is the input for [Session.SendToolResponse].
// It is an alias of [LiveSendToolResponseParameters], so the two names are
// interchangeable.
type LiveToolResponseInput = LiveSendToolResponseParameters
214
215// Preview. SendToolResponse transmits a [LiveClientToolResponse] over the established WebSocket connection.
216//
217// Use SendToolResponse to reply to [LiveServerToolCall] messages received from the server.
218//
219// To define the available tools for the session, set the [LiveConnectConfig.Tools]
220// field when establishing the connection via [Live.Connect].
221func (s *Session) SendToolResponse(input LiveToolResponseInput) error {
222	return s.send(input.toLiveClientMessage())
223}
224
225// Send transmits a LiveClientMessage over the established connection.
226// It returns an error if sending the message fails.
227func (s *Session) send(input *LiveClientMessage) error {
228	if input.Setup != nil {
229		return fmt.Errorf("message SetUp is not supported in Send(). Use Connect() instead")
230	}
231
232	parameterMap := make(map[string]any)
233	err := deepMarshal(input, &parameterMap)
234	if err != nil {
235		return err
236	}
237
238	var toConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
239	if s.apiClient.clientConfig.Backend == BackendVertexAI {
240		toConverter = liveClientMessageToVertex
241	} else {
242		toConverter = liveClientMessageToMldev
243	}
244	body, err := toConverter(s.apiClient, parameterMap, nil)
245	if err != nil {
246		return err
247	}
248
249	data, err := json.Marshal(body)
250	if err != nil {
251		return fmt.Errorf("marshal client message error: %w", err)
252	}
253	return s.conn.WriteMessage(websocket.TextMessage, []byte(data))
254}
255
256// Preview. Receive reads a LiveServerMessage from the connection.
257//
258// This method blocks until a message is received from the server.
259// The returned message represents a part of or a complete model turn.
260// If the received message is a [LiveServerToolCall], the user must call
261// [SendToolResponse] to provide the function execution result and continue the turn.
262func (s *Session) Receive() (*LiveServerMessage, error) {
263	messageType, msgBytes, err := s.conn.ReadMessage()
264	if err != nil {
265		return nil, err
266	}
267	responseMap := make(map[string]any)
268	err = json.Unmarshal(msgBytes, &responseMap)
269	if err != nil {
270		return nil, fmt.Errorf("invalid message format. Error %w. messageType: %d, message: %s", err, messageType, msgBytes)
271	}
272	if responseMap["error"] != nil {
273		return nil, fmt.Errorf("received error in response: %v", string(msgBytes))
274	}
275
276	var fromConverter func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
277	if s.apiClient.clientConfig.Backend == BackendVertexAI {
278		fromConverter = liveServerMessageFromVertex
279	} else {
280		fromConverter = liveServerMessageFromMldev
281	}
282	responseMap, err = fromConverter(s.apiClient, responseMap, nil)
283	if err != nil {
284		return nil, err
285	}
286
287	var message = new(LiveServerMessage)
288	err = mapToStruct(responseMap, message)
289	if err != nil {
290		return nil, err
291	}
292	return message, err
293}
294
295// Preview. Close terminates the connection.
296func (s *Session) Close() error {
297	if s != nil && s.conn != nil {
298		return s.conn.Close()
299	}
300	return nil
301}