transformer.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"fmt"
 19	"log"
 20	"regexp"
 21	"strings"
 22)
 23
 24func tResourceName(ac *apiClient, resourceName string, collectionIdentifier string, collectionHierarchyDepth int) string {
 25	shouldPrependCollectionIdentifier := !strings.HasPrefix(resourceName, collectionIdentifier+"/") &&
 26		strings.Count(collectionIdentifier+"/"+resourceName, "/")+1 == collectionHierarchyDepth
 27
 28	switch ac.clientConfig.Backend {
 29	case BackendVertexAI:
 30		if strings.HasPrefix(resourceName, "projects/") {
 31			return resourceName
 32		} else if strings.HasPrefix(resourceName, "locations/") {
 33			return fmt.Sprintf("projects/%s/%s", ac.clientConfig.Project, resourceName)
 34		} else if strings.HasPrefix(resourceName, collectionIdentifier+"/") {
 35			return fmt.Sprintf("projects/%s/locations/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, resourceName)
 36		} else if shouldPrependCollectionIdentifier {
 37			return fmt.Sprintf("projects/%s/locations/%s/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, collectionIdentifier, resourceName)
 38		} else {
 39			return resourceName
 40		}
 41	default:
 42		if shouldPrependCollectionIdentifier {
 43			return fmt.Sprintf("%s/%s", collectionIdentifier, resourceName)
 44		} else {
 45			return resourceName
 46		}
 47	}
 48}
 49
 50func tCachedContentName(ac *apiClient, name any) (string, error) {
 51	return tResourceName(ac, name.(string), "cachedContents", 2), nil
 52}
 53
 54func tModel(ac *apiClient, origin any) (string, error) {
 55	switch model := origin.(type) {
 56	case string:
 57		if model == "" {
 58			return "", fmt.Errorf("tModel: model is empty")
 59		}
 60		if ac.clientConfig.Backend == BackendVertexAI {
 61			if strings.HasPrefix(model, "projects/") || strings.HasPrefix(model, "models/") || strings.HasPrefix(model, "publishers/") {
 62				return model, nil
 63			} else if strings.Contains(model, "/") {
 64				parts := strings.SplitN(model, "/", 2)
 65				return fmt.Sprintf("publishers/%s/models/%s", parts[0], parts[1]), nil
 66			} else {
 67				return fmt.Sprintf("publishers/google/models/%s", model), nil
 68			}
 69		} else {
 70			if strings.HasPrefix(model, "models/") || strings.HasPrefix(model, "tunedModels/") {
 71				return model, nil
 72			} else {
 73				return fmt.Sprintf("models/%s", model), nil
 74			}
 75		}
 76	default:
 77		return "", fmt.Errorf("tModel: model is not a string")
 78	}
 79}
 80
 81func tModelFullName(ac *apiClient, origin any) (string, error) {
 82	switch model := origin.(type) {
 83	case string:
 84		name, err := tModel(ac, model)
 85		if err != nil {
 86			return "", fmt.Errorf("tModelFullName: %w", err)
 87		}
 88		if strings.HasPrefix(name, "publishers/") && ac.clientConfig.Backend == BackendVertexAI {
 89			return fmt.Sprintf("projects/%s/locations/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, name), nil
 90		} else if strings.HasPrefix(name, "models/") && ac.clientConfig.Backend == BackendVertexAI {
 91			return fmt.Sprintf("projects/%s/locations/%s/publishers/google/%s", ac.clientConfig.Project, ac.clientConfig.Location, name), nil
 92		} else {
 93			return name, nil
 94		}
 95	default:
 96		return "", fmt.Errorf("tModelFullName: model is not a string")
 97	}
 98}
 99
100func tCachesModel(ac *apiClient, origin any) (string, error) {
101	return tModelFullName(ac, origin)
102}
103
104func tContent(_ *apiClient, content any) (any, error) {
105	return content, nil
106}
107
108func tContents(_ *apiClient, contents any) (any, error) {
109	return contents, nil
110}
111
112func tTool(_ *apiClient, tool any) (any, error) {
113	return tool, nil
114}
115
116func tTools(_ *apiClient, tools any) (any, error) {
117	return tools, nil
118}
119
120func tSchema(apiClient *apiClient, origin any) (any, error) {
121	return origin, nil
122}
123
124func tSpeechConfig(_ *apiClient, speechConfig any) (any, error) {
125	return speechConfig, nil
126}
127
128func tBytes(_ *apiClient, fromImageBytes any) (any, error) {
129	// TODO(b/389133914): Remove dummy bytes converter.
130	return fromImageBytes, nil
131}
132
133func tContentsForEmbed(ac *apiClient, contents any) (any, error) {
134	if ac.clientConfig.Backend == BackendVertexAI {
135		switch v := contents.(type) {
136		case []any:
137			texts := []string{}
138			for _, content := range v {
139				parts, ok := content.(map[string]any)["parts"].([]any)
140				if !ok || len(parts) == 0 {
141					return nil, fmt.Errorf("tContentsForEmbed: content parts is not a non-empty list")
142				}
143				text, ok := parts[0].(map[string]any)["text"].(string)
144				if !ok {
145					return nil, fmt.Errorf("tContentsForEmbed: content part text is not a string")
146				}
147				texts = append(texts, text)
148			}
149			return texts, nil
150		default:
151			return nil, fmt.Errorf("tContentsForEmbed: contents is not a list")
152		}
153	} else {
154		return contents, nil
155	}
156}
157
158func tModelsURL(ac *apiClient, baseModels any) (string, error) {
159	if ac.clientConfig.Backend == BackendVertexAI {
160		if baseModels.(bool) {
161			return "publishers/google/models", nil
162		} else {
163			return "models", nil
164		}
165	} else {
166		if baseModels.(bool) {
167			return "models", nil
168		} else {
169			return "tunedModels", nil
170		}
171	}
172}
173
174func tExtractModels(ac *apiClient, response any) (any, error) {
175	switch response := response.(type) {
176	case map[string]any:
177		if models, ok := response["models"]; ok {
178			return models, nil
179		} else if tunedModels, ok := response["tunedModels"]; ok {
180			return tunedModels, nil
181		} else if publisherModels, ok := response["publisherModels"]; ok {
182			return publisherModels, nil
183		} else {
184			log.Printf("Warning: Cannot find the models type(models, tunedModels, publisherModels) for response: %s", response)
185			return []any{}, nil
186		}
187	default:
188		return nil, fmt.Errorf("tExtractModels: response is not a map")
189	}
190}
191
192func tFileName(ac *apiClient, name any) (string, error) {
193	switch name := name.(type) {
194	case string:
195		{
196			if strings.HasPrefix(name, "https://") || strings.HasPrefix(name, "http://") {
197				parts := strings.SplitN(name, "files/", 2)
198				if len(parts) < 2 {
199					return "", fmt.Errorf("could not find 'files/' in URI: %s", name)
200				}
201				suffix := parts[1]
202				re := regexp.MustCompile("^[a-z0-9]+")
203				match := re.FindStringSubmatch(suffix)
204				if len(match) == 0 {
205					return "", fmt.Errorf("could not extract file name from URI: %s", name)
206				}
207				name = match[0]
208			} else if strings.HasPrefix(name, "files/") {
209				name = strings.TrimPrefix(name, "files/")
210			}
211			return name, nil
212		}
213	default:
214		return "", fmt.Errorf("tFileName: name is not a string")
215	}
216}
217
218func tBlobs(ac *apiClient, blobs any) (any, error) {
219	switch blobs := blobs.(type) {
220	case []any:
221		// The only use case of this tBlobs function is for LiveSendRealtimeInputParameters.Media field.
222		// The Media field is a Blob type, not a list of Blob. So this branch will never be executed.
223		// If tBlobs function is used for other purposes in the future, uncomment the following line to
224		// enable this branch.
225		// applyConverterToSlice(ac, blobs, tBlob)
226		return nil, fmt.Errorf("unimplemented")
227	default:
228		blob, err := tBlob(ac, blobs)
229		if err != nil {
230			return nil, err
231		}
232		return []any{blob}, nil
233	}
234}
235
236func tBlob(ac *apiClient, blob any) (any, error) {
237	return blob, nil
238}
239
240func tImageBlob(ac *apiClient, blob any) (any, error) {
241	switch blob := blob.(type) {
242	case map[string]any:
243		if strings.HasPrefix(blob["mimeType"].(string), "image/") {
244			return blob, nil
245		}
246		return nil, fmt.Errorf("Unsupported mime type: %s", blob["mimeType"])
247	default:
248		return nil, fmt.Errorf("tImageBlob: blob is not a map")
249	}
250}
251
252func tAudioBlob(ac *apiClient, blob any) (any, error) {
253	switch blob := blob.(type) {
254	case map[string]any:
255		if strings.HasPrefix(blob["mimeType"].(string), "audio/") {
256			return blob, nil
257		}
258		return nil, fmt.Errorf("Unsupported mime type: %s", blob["mimeType"])
259	default:
260		return nil, fmt.Errorf("tAudioBlob: blob is not a map")
261	}
262}