gemini.go

  1package gemini
  2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
)
 11
// Request is the JSON body that Model.GenerateContent marshals and POSTs to
// the generateContent method.
// https://ai.google.dev/api/generate-content#request-body
//
// Contents has no omitempty, so a nil slice is sent as "contents": null; every
// other field is omitted from the JSON when unset.
type Request struct {
	Contents          []Content         `json:"contents"`
	Tools             []Tool            `json:"tools,omitempty"`
	SystemInstruction *Content          `json:"systemInstruction,omitempty"`
	GenerationConfig  *GenerationConfig `json:"generationConfig,omitempty"`
	CachedContent     string            `json:"cachedContent,omitempty"` // format: "cachedContents/{name}"
	// ToolConfig has been left out because it does not appear to be useful.
}
 21
// Response is the decoded body of a generateContent call that returned
// HTTP 200.
// https://ai.google.dev/api/generate-content#response-body
//
// Only "candidates" is decoded; encoding/json ignores any other top-level
// fields in the body. The HTTP headers of the response are kept in the
// unexported headers field and exposed through Header.
type Response struct {
	Candidates []Candidate `json:"candidates"`
	headers    http.Header // captured HTTP response headers
}
 27
 28// Header returns the HTTP response headers.
 29func (r *Response) Header() http.Header {
 30	return r.headers
 31}
 32
// Candidate is one entry of Response.Candidates. Only its "content" field is
// decoded; any other fields of the JSON object are ignored.
type Candidate struct {
	Content Content `json:"content"`
}
 36
// Content is a sequence of parts, optionally tagged with the role of its
// author. Role is omitted from the JSON when empty.
// NOTE(review): the API's role values (presumably "user" / "model") are not
// defined in this file — confirm against the API reference.
type Content struct {
	Parts []Part `json:"parts"`
	Role  string `json:"role,omitempty"`
}
 41
// Part is a part of the content.
// This is a union data structure: only one of the fields should be set.
// Every field is omitempty, so a Part encodes as an object holding just the
// field that is set.
type Part struct {
	Text                string               `json:"text,omitempty"`
	FunctionCall        *FunctionCall        `json:"functionCall,omitempty"`
	FunctionResponse    *FunctionResponse    `json:"functionResponse,omitempty"`
	ExecutableCode      *ExecutableCode      `json:"executableCode,omitempty"`
	CodeExecutionResult *CodeExecutionResult `json:"codeExecutionResult,omitempty"`
	// TODO inlineData
	// TODO fileData
}
 53
// FunctionCall asks for the function Name to be called with Args. Args is
// decoded as a generic JSON object.
type FunctionCall struct {
	Name string         `json:"name"`
	Args map[string]any `json:"args"`
}
 58
// FunctionResponse carries the result of a function call back to the model.
// Response is an arbitrary JSON object.
// NOTE(review): Name presumably has to match the FunctionCall.Name it
// answers — confirm against the API reference.
type FunctionResponse struct {
	Name     string         `json:"name"`
	Response map[string]any `json:"response"`
}
 63
// ExecutableCode is a piece of source code together with its language; it
// is the code whose run is reported by CodeExecutionResult.
type ExecutableCode struct {
	Language Language `json:"language"`
	Code     string   `json:"code"`
}
 68
 69type Language int
 70
 71const (
 72	LanguageUnspecified Language = 0
 73	LanguagePython      Language = 1 // python >= 3.10 with numpy and simpy
 74)
 75
// CodeExecutionResult reports how a code run ended (Outcome) together with
// the run's output text.
type CodeExecutionResult struct {
	Outcome Outcome `json:"outcome"`
	Output  string  `json:"output"`
}
 80
 81type Outcome int
 82
 83const (
 84	OutcomeUnspecified      Outcome = 0
 85	OutcomeOK               Outcome = 1
 86	OutcomeFailed           Outcome = 2
 87	OutcomeDeadlineExceeded Outcome = 3
 88)
 89
// GenerationConfig holds optional generation settings. Both fields are
// omitted from the JSON when unset.
// https://ai.google.dev/api/generate-content#v1beta.GenerationConfig
type GenerationConfig struct {
	ResponseMimeType string  `json:"responseMimeType,omitempty"` // text/plain, application/json, or text/x.enum
	ResponseSchema   *Schema `json:"responseSchema,omitempty"`   // for JSON
}
 95
 96// https://ai.google.dev/api/caching#Tool
 97type Tool struct {
 98	FunctionDeclarations []FunctionDeclaration `json:"functionDeclarations"`
 99	CodeExecution        *struct{}             `json:"codeExecution,omitempty"` // if present, enables the model to execute code
100	// TODO googleSearchRetrieval https://ai.google.dev/api/caching#GoogleSearchRetrieval
101}
102
// FunctionDeclaration describes a function the model may call, by name,
// description and parameter schema.
// https://ai.google.dev/api/caching#FunctionDeclaration
//
// All three fields are always sent. Parameters is a struct value, so a
// declaration with no parameters still encodes it, as {"type":0}.
// NOTE(review): confirm the API accepts an unspecified parameters type for
// parameterless functions.
type FunctionDeclaration struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Parameters  Schema `json:"parameters"`
}
109
// Schema describes the type of a value. It is used both for function
// parameters (FunctionDeclaration.Parameters) and for structured responses
// (GenerationConfig.ResponseSchema).
// https://ai.google.dev/api/caching#Schema
//
// MaxItems and MinItems are strings, presumably because the API's int64
// fields are written as decimal strings in proto3 JSON.
type Schema struct {
	Type        DataType          `json:"type"`
	Format      string            `json:"format,omitempty"` // NUMBER: float, double; INTEGER: int32, int64; STRING: enum
	Description string            `json:"description,omitempty"`
	Nullable    *bool             `json:"nullable,omitempty"`
	Enum        []string          `json:"enum,omitempty"`
	MaxItems    string            `json:"maxItems,omitempty"`   // for ARRAY
	MinItems    string            `json:"minItems,omitempty"`   // for ARRAY
	Properties  map[string]Schema `json:"properties,omitempty"` // for OBJECT
	Required    []string          `json:"required,omitempty"`   // for OBJECT
	Items       *Schema           `json:"items,omitempty"`      // for ARRAY
}

// DataType is the type of a Schema. It is encoded as its integer value.
// NOTE(review): these values presumably mirror the API's Type enum numbering
// (TYPE_UNSPECIFIED = 0 … OBJECT = 6) — confirm against the API reference.
type DataType int

const (
	DataTypeUNSPECIFIED = DataType(0) // Not specified, should not be used.
	DataTypeSTRING      = DataType(1)
	DataTypeNUMBER      = DataType(2)
	DataTypeINTEGER     = DataType(3)
	DataTypeBOOLEAN     = DataType(4)
	DataTypeARRAY       = DataType(5)
	DataTypeOBJECT      = DataType(6)
)
135
// defaultEndpoint is the API base URL used when Model.Endpoint is empty.
const defaultEndpoint = "https://generativelanguage.googleapis.com/v1beta"
137
// Model holds what is needed to call one model through the REST API.
// Requests are POSTed to {Endpoint}/{Model}:generateContent, with APIKey
// passed as the "key" query parameter.
type Model struct {
	Model    string // e.g. "models/gemini-1.5-flash"
	APIKey   string
	HTTPC    *http.Client // if nil, http.DefaultClient is used
	Endpoint string       // if empty, defaultEndpoint is used
}
144
145func (m Model) GenerateContent(ctx context.Context, req *Request) (*Response, error) {
146	reqBytes, err := json.Marshal(req)
147	if err != nil {
148		return nil, fmt.Errorf("marshaling request: %w", err)
149	}
150	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/%s:generateContent?key=%s", m.endpoint(), m.Model, m.APIKey), bytes.NewReader(reqBytes))
151	if err != nil {
152		return nil, fmt.Errorf("creating HTTP request: %w", err)
153	}
154	httpReq.Header.Add("Content-Type", "application/json")
155	httpResp, err := m.httpc().Do(httpReq)
156	if err != nil {
157		return nil, fmt.Errorf("GenerateContent: do: %w", err)
158	}
159	defer httpResp.Body.Close()
160	body, err := io.ReadAll(httpResp.Body)
161	if err != nil {
162		return nil, fmt.Errorf("GenerateContent: reading response body: %w", err)
163	}
164	if httpResp.StatusCode != http.StatusOK {
165		return nil, fmt.Errorf("GenerateContent: HTTP status: %d, %s", httpResp.StatusCode, string(body))
166	}
167	var res Response
168	if err := json.Unmarshal(body, &res); err != nil {
169		return nil, fmt.Errorf("GenerateContent: unmarshaling response: %w, %s", err, string(body))
170	}
171	res.headers = httpResp.Header
172	return &res, nil
173}
174
175func (m Model) endpoint() string {
176	if m.Endpoint != "" {
177		return m.Endpoint
178	}
179	return defaultEndpoint
180}
181
182func (m Model) httpc() *http.Client {
183	if m.HTTPC != nil {
184		return m.HTTPC
185	}
186	return http.DefaultClient
187}