Detailed changes
@@ -83,13 +83,6 @@ type ViewableModel interface {
View() View
}
-// StyledString returns a [Layer] that can be styled with ANSI escape
-// codes. It is used to render text with different colors, styles, and other
-// attributes on the terminal screen.
-func StyledString(s string) *uv.StyledString {
- return uv.NewStyledString(s)
-}
-
// Buffer represents a terminal cell buffer that defines the current state of
// the terminal screen.
type Buffer = uv.Buffer
@@ -127,13 +120,13 @@ func NewView(s any) View {
var view View
switch v := s.(type) {
case string:
- view.Layer = StyledString(v)
+ view.Layer = uv.NewStyledString(v)
case fmt.Stringer:
- view.Layer = StyledString(v.String())
+ view.Layer = uv.NewStyledString(v.String())
case Layer:
view.Layer = v
default:
- view.Layer = StyledString(fmt.Sprintf("%v", v))
+ view.Layer = uv.NewStyledString(fmt.Sprintf("%v", v))
}
return view
}
@@ -867,7 +860,7 @@ func (p *Program) render(model Model) {
case CursorModel:
frame, view.Cursor = model.View()
}
- view.Layer = StyledString(frame)
+ view.Layer = uv.NewStyledString(frame)
view.BackgroundColor = p.lastBgColor
view.ForegroundColor = p.lastFgColor
view.WindowTitle = p.lastWindowTitle
@@ -543,6 +543,7 @@ func (s Style) getAsTransform(propKey) func(string) string {
// line.
func getLines(s string) (lines []string, widest int) {
s = strings.ReplaceAll(s, "\t", " ")
+ s = strings.ReplaceAll(s, "\r\n", "\n")
lines = strings.Split(s, "\n")
for _, l := range lines {
@@ -0,0 +1,5 @@
+Copyright ยฉ 2025 Charmbracelet, Inc.
+
+This project is currently in development and is not yet licensed for public use.
+
+All rights reserved. No license is granted at this time.
@@ -9,18 +9,12 @@
> [!CAUTION]
> This project is in very early development and may change significantly at any moment. Expect no API guarantees as of now.
-Ultraviolet is a Go library for building text-based applications with a focus
-on terminal user interfaces (TUIs). It provides a set of tools and abstractions
-to create interactive terminal applications that can handle user input and
-display dynamic content in a cell-based manner.
+Ultraviolet is a set of of primitives for manipulating terminal emulators with
+a focus on terminal user interfaces (TUIs). It provides a set of tools and
+abstractions for interactive terminal applications that can handle user input
+and display dynamic, cell-based content.
-## Installation
-
-To install UV, you can simply run in your Go project:
-
-```bash
-go get github.com/charmbracelet/ultraviolet
-```
+[Ultraviolet is not yet licensed.](https://github.com/charmbracelet/ultraviolet/raw/main/LICENSE)
---
@@ -378,7 +378,10 @@ func (s *TerminalRenderer) moveCursor(newbuf *Buffer, x, y int, overwrite bool)
s.buf.WriteByte('\r') //nolint:errcheck
s.cur.X, s.cur.Y = 0, 0
}
- s.buf.WriteString(moveCursor(s, newbuf, x, y, overwrite)) //nolint:errcheck
+ seq, scrollHeight := moveCursor(s, newbuf, x, y, overwrite)
+ // If we scrolled the screen, we need to update the scroll height.
+ s.scrollHeight = max(s.scrollHeight, scrollHeight)
+ s.buf.WriteString(seq) //nolint:errcheck
s.cur.X, s.cur.Y = x, y
}
@@ -1316,13 +1319,11 @@ func notLocal(cols, fx, fy, tx, ty int) bool {
//
// It is safe to call this function with a nil [Buffer]. In that case, it won't
// use any optimizations that require the new buffer such as overwrite.
-func relativeCursorMove(s *TerminalRenderer, newbuf *Buffer, fx, fy, tx, ty int, overwrite, useTabs, useBackspace bool) string {
+func relativeCursorMove(s *TerminalRenderer, newbuf *Buffer, fx, fy, tx, ty int, overwrite, useTabs, useBackspace bool) (string, int) {
var seq strings.Builder
- height := -1
+ var scrollHeight int
if newbuf == nil {
overwrite = false // We can't overwrite the current buffer.
- } else {
- height = newbuf.Height()
}
if ty != fy {
@@ -1331,21 +1332,15 @@ func relativeCursorMove(s *TerminalRenderer, newbuf *Buffer, fx, fy, tx, ty int,
yseq = ansi.VerticalPositionAbsolute(ty + 1)
}
- // OPTIM: Use [ansi.LF] and [ansi.ReverseIndex] as optimizations.
-
if ty > fy {
n := ty - fy
if cud := ansi.CursorDown(n); yseq == "" || len(cud) < len(yseq) {
yseq = cud
}
shouldScroll := !s.flags.Contains(tAltScreen) && ty > s.scrollHeight
- if shouldScroll && ty == s.scrollHeight && ty < height {
- n = min(n, height-1-ty)
- }
- if lf := strings.Repeat("\n", n); shouldScroll ||
- ((ty < height || height == -1) && len(lf) < len(yseq)) {
+ if lf := strings.Repeat("\n", n); shouldScroll || len(lf) < len(yseq) {
yseq = lf
- s.scrollHeight = max(s.scrollHeight, ty)
+ scrollHeight = ty
if s.flags.Contains(tMapNewline) {
fx = 0
}
@@ -1468,7 +1463,7 @@ func relativeCursorMove(s *TerminalRenderer, newbuf *Buffer, fx, fy, tx, ty int,
seq.WriteString(xseq)
}
- return seq.String()
+ return seq.String(), scrollHeight
}
// moveCursor moves and returns the cursor movement sequence to move the cursor
@@ -1478,7 +1473,7 @@ func relativeCursorMove(s *TerminalRenderer, newbuf *Buffer, fx, fy, tx, ty int,
//
// It is safe to call this function with a nil [Buffer]. In that case, it won't
// use any optimizations that require the new buffer such as overwrite.
-func moveCursor(s *TerminalRenderer, newbuf *Buffer, x, y int, overwrite bool) (seq string) {
+func moveCursor(s *TerminalRenderer, newbuf *Buffer, x, y int, overwrite bool) (seq string, scrollHeight int) {
fx, fy := s.cur.X, s.cur.Y
if !s.flags.Contains(tRelativeCursor) {
@@ -1496,7 +1491,7 @@ func moveCursor(s *TerminalRenderer, newbuf *Buffer, x, y int, overwrite bool) (
// Method #0: Use [ansi.CUP] if the distance is long.
seq = ansi.CursorPosition(x+1, y+1)
if fx == -1 || fy == -1 || width == -1 || notLocal(width, fx, fy, x, y) {
- return
+ return seq, 0
}
}
@@ -1520,27 +1515,32 @@ func moveCursor(s *TerminalRenderer, newbuf *Buffer, x, y int, overwrite bool) (
useBackspace := i&1 != 0
// Method #1: Use local movement sequences.
- nseq := relativeCursorMove(s, newbuf, fx, fy, x, y, overwrite, useHardTabs, useBackspace)
- if (i == 0 && len(seq) == 0) || len(nseq) < len(seq) {
- seq = nseq
+ nseq1, nscrollHeight1 := relativeCursorMove(s, newbuf, fx, fy, x, y, overwrite, useHardTabs, useBackspace)
+ if (i == 0 && len(seq) == 0) || len(nseq1) < len(seq) {
+ seq = nseq1
+ scrollHeight = max(scrollHeight, nscrollHeight1)
}
// Method #2: Use [ansi.CR] and local movement sequences.
- nseq = "\r" + relativeCursorMove(s, newbuf, 0, fy, x, y, overwrite, useHardTabs, useBackspace)
- if len(nseq) < len(seq) {
- seq = nseq
+ nseq2, nscrollHeight2 := relativeCursorMove(s, newbuf, 0, fy, x, y, overwrite, useHardTabs, useBackspace)
+ nseq2 = "\r" + nseq2
+ if len(nseq2) < len(seq) {
+ seq = nseq2
+ scrollHeight = max(scrollHeight, nscrollHeight2)
}
if !s.flags.Contains(tRelativeCursor) {
// Method #3: Use [ansi.CursorHomePosition] and local movement sequences.
- nseq = ansi.CursorHomePosition + relativeCursorMove(s, newbuf, 0, 0, x, y, overwrite, useHardTabs, useBackspace)
- if len(nseq) < len(seq) {
- seq = nseq
+ nseq3, nscrollHeight3 := relativeCursorMove(s, newbuf, 0, 0, x, y, overwrite, useHardTabs, useBackspace)
+ nseq3 = ansi.CursorHomePosition + nseq3
+ if len(nseq3) < len(seq) {
+ seq = nseq3
+ scrollHeight = max(scrollHeight, nscrollHeight3)
}
}
}
- return
+ return seq, scrollHeight
}
// xtermCaps returns whether the terminal is xterm-like. This means that the
@@ -24,9 +24,6 @@ const newIndex = -1
// updateHashmap updates the hashmap with the new hash value.
func (s *TerminalRenderer) updateHashmap(newbuf *Buffer) {
height := newbuf.Height()
- if s.hashtab == nil || height >= len(s.hashtab) {
- s.hashtab = make([]hashmap, (height+1)*2)
- }
if len(s.oldhash) >= height && len(s.newhash) >= height {
// rehash changed lines
@@ -53,10 +50,7 @@ func (s *TerminalRenderer) updateHashmap(newbuf *Buffer) {
}
}
- for i := 0; i < len(s.hashtab); i++ {
- s.hashtab[i] = hashmap{}
- }
-
+ s.hashtab = make([]hashmap, (height+1)*2)
for i := 0; i < height; i++ {
hashval := s.oldhash[i]
@@ -11,3 +11,8 @@ import "os"
func OpenTTY() (inTty, outTty *os.File, err error) {
return openTTY()
}
+
+// Suspend suspends the current process group.
+func Suspend() error {
+ return suspend()
+}
@@ -8,3 +8,7 @@ import "os"
func openTTY() (*os.File, *os.File, error) {
return nil, nil, ErrPlatformNotSupported
}
+
+func suspend() error {
+ return ErrPlatformNotSupported
+}
@@ -3,7 +3,11 @@
package uv
-import "os"
+import (
+ "os"
+ "os/signal"
+ "syscall"
+)
func openTTY() (inTty, outTty *os.File, err error) {
f, err := os.OpenFile("/dev/tty", os.O_RDWR, 0)
@@ -12,3 +16,13 @@ func openTTY() (inTty, outTty *os.File, err error) {
}
return f, f, nil
}
+
+func suspend() (err error) {
+ // Send SIGTSTP to the entire process group.
+ c := make(chan os.Signal, 1)
+ signal.Notify(c, syscall.SIGCONT)
+ err = syscall.Kill(0, syscall.SIGTSTP)
+ // blocks until a CONT happens...
+ <-c
+ return
+}
@@ -19,3 +19,9 @@ func openTTY() (inTty, outTty *os.File, err error) {
}
return inTty, outTty, nil
}
+
+func suspend() (err error) {
+ // On Windows, suspending the process group is not supported in the same
+ // way as Unix-like systems.
+ return nil
+}
@@ -1,84 +1,434 @@
-// Package client provides MCP (Model Control Protocol) client implementations.
package client
import (
"context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)
-// MCPClient represents an MCP client interface
-type MCPClient interface {
- // Initialize sends the initial connection request to the server
- Initialize(
- ctx context.Context,
- request mcp.InitializeRequest,
- ) (*mcp.InitializeResult, error)
-
- // Ping checks if the server is alive
- Ping(ctx context.Context) error
-
- // ListResources requests a list of available resources from the server
- ListResources(
- ctx context.Context,
- request mcp.ListResourcesRequest,
- ) (*mcp.ListResourcesResult, error)
-
- // ListResourceTemplates requests a list of available resource templates from the server
- ListResourceTemplates(
- ctx context.Context,
- request mcp.ListResourceTemplatesRequest,
- ) (*mcp.ListResourceTemplatesResult,
- error)
-
- // ReadResource reads a specific resource from the server
- ReadResource(
- ctx context.Context,
- request mcp.ReadResourceRequest,
- ) (*mcp.ReadResourceResult, error)
-
- // Subscribe requests notifications for changes to a specific resource
- Subscribe(ctx context.Context, request mcp.SubscribeRequest) error
-
- // Unsubscribe cancels notifications for a specific resource
- Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error
-
- // ListPrompts requests a list of available prompts from the server
- ListPrompts(
- ctx context.Context,
- request mcp.ListPromptsRequest,
- ) (*mcp.ListPromptsResult, error)
-
- // GetPrompt retrieves a specific prompt from the server
- GetPrompt(
- ctx context.Context,
- request mcp.GetPromptRequest,
- ) (*mcp.GetPromptResult, error)
-
- // ListTools requests a list of available tools from the server
- ListTools(
- ctx context.Context,
- request mcp.ListToolsRequest,
- ) (*mcp.ListToolsResult, error)
-
- // CallTool invokes a specific tool on the server
- CallTool(
- ctx context.Context,
- request mcp.CallToolRequest,
- ) (*mcp.CallToolResult, error)
-
- // SetLevel sets the logging level for the server
- SetLevel(ctx context.Context, request mcp.SetLevelRequest) error
-
- // Complete requests completion options for a given argument
- Complete(
- ctx context.Context,
- request mcp.CompleteRequest,
- ) (*mcp.CompleteResult, error)
-
- // Close client connection and cleanup resources
- Close() error
-
- // OnNotification registers a handler for notifications
- OnNotification(handler func(notification mcp.JSONRPCNotification))
+// Client implements the MCP client.
+type Client struct {
+ transport transport.Interface
+
+ initialized bool
+ notifications []func(mcp.JSONRPCNotification)
+ notifyMu sync.RWMutex
+ requestID atomic.Int64
+ clientCapabilities mcp.ClientCapabilities
+ serverCapabilities mcp.ServerCapabilities
+}
+
+type ClientOption func(*Client)
+
+// WithClientCapabilities sets the client capabilities for the client.
+func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
+ return func(c *Client) {
+ c.clientCapabilities = capabilities
+ }
+}
+
+// NewClient creates a new MCP client with the given transport.
+// Usage:
+//
+// stdio := transport.NewStdio("./mcp_server", nil, "xxx")
+// client, err := NewClient(stdio)
+// if err != nil {
+// log.Fatalf("Failed to create client: %v", err)
+// }
+func NewClient(transport transport.Interface, options ...ClientOption) *Client {
+ client := &Client{
+ transport: transport,
+ }
+
+ for _, opt := range options {
+ opt(client)
+ }
+
+ return client
+}
+
+// Start initiates the connection to the server.
+// Must be called before using the client.
+func (c *Client) Start(ctx context.Context) error {
+ if c.transport == nil {
+ return fmt.Errorf("transport is nil")
+ }
+ err := c.transport.Start(ctx)
+ if err != nil {
+ return err
+ }
+
+ c.transport.SetNotificationHandler(func(notification mcp.JSONRPCNotification) {
+ c.notifyMu.RLock()
+ defer c.notifyMu.RUnlock()
+ for _, handler := range c.notifications {
+ handler(notification)
+ }
+ })
+ return nil
+}
+
+// Close shuts down the client and closes the transport.
+func (c *Client) Close() error {
+ return c.transport.Close()
+}
+
+// OnNotification registers a handler function to be called when notifications are received.
+// Multiple handlers can be registered and will be called in the order they were added.
+func (c *Client) OnNotification(
+ handler func(notification mcp.JSONRPCNotification),
+) {
+ c.notifyMu.Lock()
+ defer c.notifyMu.Unlock()
+ c.notifications = append(c.notifications, handler)
+}
+
+// sendRequest sends a JSON-RPC request to the server and waits for a response.
+// Returns the raw JSON response message or an error if the request fails.
+func (c *Client) sendRequest(
+ ctx context.Context,
+ method string,
+ params any,
+) (*json.RawMessage, error) {
+ if !c.initialized && method != "initialize" {
+ return nil, fmt.Errorf("client not initialized")
+ }
+
+ id := c.requestID.Add(1)
+
+ request := transport.JSONRPCRequest{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: mcp.NewRequestId(id),
+ Method: method,
+ Params: params,
+ }
+
+ response, err := c.transport.SendRequest(ctx, request)
+ if err != nil {
+ return nil, fmt.Errorf("transport error: %w", err)
+ }
+
+ if response.Error != nil {
+ return nil, errors.New(response.Error.Message)
+ }
+
+ return &response.Result, nil
+}
+
+// Initialize negotiates with the server.
+// Must be called after Start, and before any request methods.
+func (c *Client) Initialize(
+ ctx context.Context,
+ request mcp.InitializeRequest,
+) (*mcp.InitializeResult, error) {
+ // Ensure we send a params object with all required fields
+ params := struct {
+ ProtocolVersion string `json:"protocolVersion"`
+ ClientInfo mcp.Implementation `json:"clientInfo"`
+ Capabilities mcp.ClientCapabilities `json:"capabilities"`
+ }{
+ ProtocolVersion: request.Params.ProtocolVersion,
+ ClientInfo: request.Params.ClientInfo,
+ Capabilities: request.Params.Capabilities, // Will be empty struct if not set
+ }
+
+ response, err := c.sendRequest(ctx, "initialize", params)
+ if err != nil {
+ return nil, err
+ }
+
+ var result mcp.InitializeResult
+ if err := json.Unmarshal(*response, &result); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+ }
+
+ // Store serverCapabilities
+ c.serverCapabilities = result.Capabilities
+
+ // Send initialized notification
+ notification := mcp.JSONRPCNotification{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ Notification: mcp.Notification{
+ Method: "notifications/initialized",
+ },
+ }
+
+ err = c.transport.SendNotification(ctx, notification)
+ if err != nil {
+ return nil, fmt.Errorf(
+ "failed to send initialized notification: %w",
+ err,
+ )
+ }
+
+ c.initialized = true
+ return &result, nil
+}
+
+func (c *Client) Ping(ctx context.Context) error {
+ _, err := c.sendRequest(ctx, "ping", nil)
+ return err
+}
+
+// ListResourcesByPage manually list resources by page.
+func (c *Client) ListResourcesByPage(
+ ctx context.Context,
+ request mcp.ListResourcesRequest,
+) (*mcp.ListResourcesResult, error) {
+ result, err := listByPage[mcp.ListResourcesResult](ctx, c, request.PaginatedRequest, "resources/list")
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (c *Client) ListResources(
+ ctx context.Context,
+ request mcp.ListResourcesRequest,
+) (*mcp.ListResourcesResult, error) {
+ result, err := c.ListResourcesByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ for result.NextCursor != "" {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ request.Params.Cursor = result.NextCursor
+ newPageRes, err := c.ListResourcesByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ result.Resources = append(result.Resources, newPageRes.Resources...)
+ result.NextCursor = newPageRes.NextCursor
+ }
+ }
+ return result, nil
+}
+
+func (c *Client) ListResourceTemplatesByPage(
+ ctx context.Context,
+ request mcp.ListResourceTemplatesRequest,
+) (*mcp.ListResourceTemplatesResult, error) {
+ result, err := listByPage[mcp.ListResourceTemplatesResult](ctx, c, request.PaginatedRequest, "resources/templates/list")
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (c *Client) ListResourceTemplates(
+ ctx context.Context,
+ request mcp.ListResourceTemplatesRequest,
+) (*mcp.ListResourceTemplatesResult, error) {
+ result, err := c.ListResourceTemplatesByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ for result.NextCursor != "" {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ request.Params.Cursor = result.NextCursor
+ newPageRes, err := c.ListResourceTemplatesByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ result.ResourceTemplates = append(result.ResourceTemplates, newPageRes.ResourceTemplates...)
+ result.NextCursor = newPageRes.NextCursor
+ }
+ }
+ return result, nil
+}
+
+func (c *Client) ReadResource(
+ ctx context.Context,
+ request mcp.ReadResourceRequest,
+) (*mcp.ReadResourceResult, error) {
+ response, err := c.sendRequest(ctx, "resources/read", request.Params)
+ if err != nil {
+ return nil, err
+ }
+
+ return mcp.ParseReadResourceResult(response)
+}
+
+func (c *Client) Subscribe(
+ ctx context.Context,
+ request mcp.SubscribeRequest,
+) error {
+ _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
+ return err
+}
+
+func (c *Client) Unsubscribe(
+ ctx context.Context,
+ request mcp.UnsubscribeRequest,
+) error {
+ _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
+ return err
+}
+
+func (c *Client) ListPromptsByPage(
+ ctx context.Context,
+ request mcp.ListPromptsRequest,
+) (*mcp.ListPromptsResult, error) {
+ result, err := listByPage[mcp.ListPromptsResult](ctx, c, request.PaginatedRequest, "prompts/list")
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (c *Client) ListPrompts(
+ ctx context.Context,
+ request mcp.ListPromptsRequest,
+) (*mcp.ListPromptsResult, error) {
+ result, err := c.ListPromptsByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ for result.NextCursor != "" {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ request.Params.Cursor = result.NextCursor
+ newPageRes, err := c.ListPromptsByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ result.Prompts = append(result.Prompts, newPageRes.Prompts...)
+ result.NextCursor = newPageRes.NextCursor
+ }
+ }
+ return result, nil
+}
+
+func (c *Client) GetPrompt(
+ ctx context.Context,
+ request mcp.GetPromptRequest,
+) (*mcp.GetPromptResult, error) {
+ response, err := c.sendRequest(ctx, "prompts/get", request.Params)
+ if err != nil {
+ return nil, err
+ }
+
+ return mcp.ParseGetPromptResult(response)
+}
+
+func (c *Client) ListToolsByPage(
+ ctx context.Context,
+ request mcp.ListToolsRequest,
+) (*mcp.ListToolsResult, error) {
+ result, err := listByPage[mcp.ListToolsResult](ctx, c, request.PaginatedRequest, "tools/list")
+ if err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func (c *Client) ListTools(
+ ctx context.Context,
+ request mcp.ListToolsRequest,
+) (*mcp.ListToolsResult, error) {
+ result, err := c.ListToolsByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ for result.NextCursor != "" {
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ request.Params.Cursor = result.NextCursor
+ newPageRes, err := c.ListToolsByPage(ctx, request)
+ if err != nil {
+ return nil, err
+ }
+ result.Tools = append(result.Tools, newPageRes.Tools...)
+ result.NextCursor = newPageRes.NextCursor
+ }
+ }
+ return result, nil
+}
+
+func (c *Client) CallTool(
+ ctx context.Context,
+ request mcp.CallToolRequest,
+) (*mcp.CallToolResult, error) {
+ response, err := c.sendRequest(ctx, "tools/call", request.Params)
+ if err != nil {
+ return nil, err
+ }
+
+ return mcp.ParseCallToolResult(response)
+}
+
+func (c *Client) SetLevel(
+ ctx context.Context,
+ request mcp.SetLevelRequest,
+) error {
+ _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
+ return err
+}
+
+func (c *Client) Complete(
+ ctx context.Context,
+ request mcp.CompleteRequest,
+) (*mcp.CompleteResult, error) {
+ response, err := c.sendRequest(ctx, "completion/complete", request.Params)
+ if err != nil {
+ return nil, err
+ }
+
+ var result mcp.CompleteResult
+ if err := json.Unmarshal(*response, &result); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+ }
+
+ return &result, nil
+}
+
+func listByPage[T any](
+ ctx context.Context,
+ client *Client,
+ request mcp.PaginatedRequest,
+ method string,
+) (*T, error) {
+ response, err := client.sendRequest(ctx, method, request.Params)
+ if err != nil {
+ return nil, err
+ }
+ var result T
+ if err := json.Unmarshal(*response, &result); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+ }
+ return &result, nil
+}
+
+// Helper methods
+
+// GetTransport gives access to the underlying transport layer.
+// Cast it to the specific transport type and obtain the other helper methods.
+func (c *Client) GetTransport() transport.Interface {
+ return c.transport
+}
+
+// GetServerCapabilities returns the server capabilities.
+func (c *Client) GetServerCapabilities() mcp.ServerCapabilities {
+ return c.serverCapabilities
+}
+
+// GetClientCapabilities returns the client capabilities.
+func (c *Client) GetClientCapabilities() mcp.ClientCapabilities {
+ return c.clientCapabilities
}
@@ -0,0 +1,17 @@
+package client
+
+import (
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/client/transport"
+)
+
+// NewStreamableHttpClient is a convenience method that creates a new streamable-http-based MCP client
+// with the given base URL. Returns an error if the URL is invalid.
+func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTPCOption) (*Client, error) {
+ trans, err := transport.NewStreamableHTTP(baseURL, options...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create SSE transport: %w", err)
+ }
+ return NewClient(trans), nil
+}
@@ -0,0 +1,12 @@
+package client
+
+import (
+ "github.com/mark3labs/mcp-go/client/transport"
+ "github.com/mark3labs/mcp-go/server"
+)
+
+// NewInProcessClient connect directly to a mcp server object in the same process
+func NewInProcessClient(server *server.MCPServer) (*Client, error) {
+ inProcessTransport := transport.NewInProcessTransport(server)
+ return NewClient(inProcessTransport), nil
+}
@@ -0,0 +1,109 @@
+// Package client provides MCP (Model Context Protocol) client implementations.
+package client
+
+import (
+ "context"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// MCPClient represents an MCP client interface
+type MCPClient interface {
+ // Initialize sends the initial connection request to the server
+ Initialize(
+ ctx context.Context,
+ request mcp.InitializeRequest,
+ ) (*mcp.InitializeResult, error)
+
+ // Ping checks if the server is alive
+ Ping(ctx context.Context) error
+
+ // ListResourcesByPage manually list resources by page.
+ ListResourcesByPage(
+ ctx context.Context,
+ request mcp.ListResourcesRequest,
+ ) (*mcp.ListResourcesResult, error)
+
+ // ListResources requests a list of available resources from the server
+ ListResources(
+ ctx context.Context,
+ request mcp.ListResourcesRequest,
+ ) (*mcp.ListResourcesResult, error)
+
+ // ListResourceTemplatesByPage manually list resource templates by page.
+ ListResourceTemplatesByPage(
+ ctx context.Context,
+ request mcp.ListResourceTemplatesRequest,
+ ) (*mcp.ListResourceTemplatesResult,
+ error)
+
+ // ListResourceTemplates requests a list of available resource templates from the server
+ ListResourceTemplates(
+ ctx context.Context,
+ request mcp.ListResourceTemplatesRequest,
+ ) (*mcp.ListResourceTemplatesResult,
+ error)
+
+ // ReadResource reads a specific resource from the server
+ ReadResource(
+ ctx context.Context,
+ request mcp.ReadResourceRequest,
+ ) (*mcp.ReadResourceResult, error)
+
+ // Subscribe requests notifications for changes to a specific resource
+ Subscribe(ctx context.Context, request mcp.SubscribeRequest) error
+
+ // Unsubscribe cancels notifications for a specific resource
+ Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error
+
+ // ListPromptsByPage manually list prompts by page.
+ ListPromptsByPage(
+ ctx context.Context,
+ request mcp.ListPromptsRequest,
+ ) (*mcp.ListPromptsResult, error)
+
+ // ListPrompts requests a list of available prompts from the server
+ ListPrompts(
+ ctx context.Context,
+ request mcp.ListPromptsRequest,
+ ) (*mcp.ListPromptsResult, error)
+
+ // GetPrompt retrieves a specific prompt from the server
+ GetPrompt(
+ ctx context.Context,
+ request mcp.GetPromptRequest,
+ ) (*mcp.GetPromptResult, error)
+
+ // ListToolsByPage manually list tools by page.
+ ListToolsByPage(
+ ctx context.Context,
+ request mcp.ListToolsRequest,
+ ) (*mcp.ListToolsResult, error)
+
+ // ListTools requests a list of available tools from the server
+ ListTools(
+ ctx context.Context,
+ request mcp.ListToolsRequest,
+ ) (*mcp.ListToolsResult, error)
+
+ // CallTool invokes a specific tool on the server
+ CallTool(
+ ctx context.Context,
+ request mcp.CallToolRequest,
+ ) (*mcp.CallToolResult, error)
+
+ // SetLevel sets the logging level for the server
+ SetLevel(ctx context.Context, request mcp.SetLevelRequest) error
+
+ // Complete requests completion options for a given argument
+ Complete(
+ ctx context.Context,
+ request mcp.CompleteRequest,
+ ) (*mcp.CompleteResult, error)
+
+ // Close client connection and cleanup resources
+ Close() error
+
+ // OnNotification registers a handler for notifications
+ OnNotification(handler func(notification mcp.JSONRPCNotification))
+}
@@ -0,0 +1,76 @@
+package client
+
+import (
+ "errors"
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/client/transport"
+)
+
+// OAuthConfig is a convenience type that wraps transport.OAuthConfig
+type OAuthConfig = transport.OAuthConfig
+
+// Token is a convenience type that wraps transport.Token
+type Token = transport.Token
+
+// TokenStore is a convenience type that wraps transport.TokenStore
+type TokenStore = transport.TokenStore
+
+// MemoryTokenStore is a convenience type that wraps transport.MemoryTokenStore
+type MemoryTokenStore = transport.MemoryTokenStore
+
+// NewMemoryTokenStore is a convenience function that wraps transport.NewMemoryTokenStore
+var NewMemoryTokenStore = transport.NewMemoryTokenStore
+
+// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
+// Returns an error if the URL is invalid.
+func NewOAuthStreamableHttpClient(baseURL string, oauthConfig OAuthConfig, options ...transport.StreamableHTTPCOption) (*Client, error) {
+ // Add OAuth option to the list of options
+ options = append(options, transport.WithHTTPOAuth(oauthConfig))
+
+ trans, err := transport.NewStreamableHTTP(baseURL, options...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create HTTP transport: %w", err)
+ }
+ return NewClient(trans), nil
+}
+
+// NewOAuthStreamableHttpClient creates a new streamable-http-based MCP client with OAuth support.
+// Returns an error if the URL is invalid.
+func NewOAuthSSEClient(baseURL string, oauthConfig OAuthConfig, options ...transport.ClientOption) (*Client, error) {
+ // Add OAuth option to the list of options
+ options = append(options, transport.WithOAuth(oauthConfig))
+
+ trans, err := transport.NewSSE(baseURL, options...)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create SSE transport: %w", err)
+ }
+ return NewClient(trans), nil
+}
+
+// GenerateCodeVerifier generates a code verifier for PKCE
+var GenerateCodeVerifier = transport.GenerateCodeVerifier
+
+// GenerateCodeChallenge generates a code challenge from a code verifier
+var GenerateCodeChallenge = transport.GenerateCodeChallenge
+
+// GenerateState generates a state parameter for OAuth
+var GenerateState = transport.GenerateState
+
+// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
+type OAuthAuthorizationRequiredError = transport.OAuthAuthorizationRequiredError
+
+// IsOAuthAuthorizationRequiredError checks if an error is an OAuthAuthorizationRequiredError
+func IsOAuthAuthorizationRequiredError(err error) bool {
+ var target *OAuthAuthorizationRequiredError
+ return errors.As(err, &target)
+}
+
+// GetOAuthHandler extracts the OAuthHandler from an OAuthAuthorizationRequiredError
+func GetOAuthHandler(err error) *transport.OAuthHandler {
+ var oauthErr *OAuthAuthorizationRequiredError
+ if errors.As(err, &oauthErr) {
+ return oauthErr.Handler
+ }
+ return nil
+}
@@ -1,588 +1,42 @@
package client
import (
- "bufio"
- "bytes"
- "context"
- "encoding/json"
- "errors"
"fmt"
- "io"
"net/http"
"net/url"
- "strings"
- "sync"
- "sync/atomic"
- "time"
- "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/client/transport"
)
-// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE).
-// It maintains a persistent HTTP connection to receive server-pushed events
-// while sending requests over regular HTTP POST calls. The client handles
-// automatic reconnection and message routing between requests and responses.
-type SSEMCPClient struct {
- baseURL *url.URL
- endpoint *url.URL
- httpClient *http.Client
- requestID atomic.Int64
- responses map[int64]chan RPCResponse
- mu sync.RWMutex
- done chan struct{}
- initialized bool
- notifications []func(mcp.JSONRPCNotification)
- notifyMu sync.RWMutex
- endpointChan chan struct{}
- capabilities mcp.ServerCapabilities
- headers map[string]string
- sseReadTimeout time.Duration
+func WithHeaders(headers map[string]string) transport.ClientOption {
+ return transport.WithHeaders(headers)
}
-type ClientOption func(*SSEMCPClient)
-
-func WithHeaders(headers map[string]string) ClientOption {
- return func(sc *SSEMCPClient) {
- sc.headers = headers
- }
+func WithHeaderFunc(headerFunc transport.HTTPHeaderFunc) transport.ClientOption {
+ return transport.WithHeaderFunc(headerFunc)
}
-func WithSSEReadTimeout(timeout time.Duration) ClientOption {
- return func(sc *SSEMCPClient) {
- sc.sseReadTimeout = timeout
- }
+func WithHTTPClient(httpClient *http.Client) transport.ClientOption {
+ return transport.WithHTTPClient(httpClient)
}
// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL.
// Returns an error if the URL is invalid.
-func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) {
- parsedURL, err := url.Parse(baseURL)
- if err != nil {
- return nil, fmt.Errorf("invalid URL: %w", err)
- }
-
- smc := &SSEMCPClient{
- baseURL: parsedURL,
- httpClient: &http.Client{},
- responses: make(map[int64]chan RPCResponse),
- done: make(chan struct{}),
- endpointChan: make(chan struct{}),
- sseReadTimeout: 30 * time.Second,
- headers: make(map[string]string),
- }
-
- for _, opt := range options {
- opt(smc)
- }
-
- return smc, nil
-}
-
-// Start initiates the SSE connection to the server and waits for the endpoint information.
-// Returns an error if the connection fails or times out waiting for the endpoint.
-func (c *SSEMCPClient) Start(ctx context.Context) error {
-
- req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
-
- if err != nil {
-
- return fmt.Errorf("failed to create request: %w", err)
-
- }
-
- req.Header.Set("Accept", "text/event-stream")
- req.Header.Set("Cache-Control", "no-cache")
- req.Header.Set("Connection", "keep-alive")
-
- resp, err := c.httpClient.Do(req)
- if err != nil {
- return fmt.Errorf("failed to connect to SSE stream: %w", err)
- }
-
- if resp.StatusCode != http.StatusOK {
- resp.Body.Close()
- return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
- }
-
- go c.readSSE(resp.Body)
-
- // Wait for the endpoint to be received
-
- select {
- case <-c.endpointChan:
- // Endpoint received, proceed
- case <-ctx.Done():
- return fmt.Errorf("context cancelled while waiting for endpoint")
- case <-time.After(30 * time.Second): // Add a timeout
- return fmt.Errorf("timeout waiting for endpoint")
- }
-
- return nil
-}
-
-// readSSE continuously reads the SSE stream and processes events.
-// It runs until the connection is closed or an error occurs.
-func (c *SSEMCPClient) readSSE(reader io.ReadCloser) {
- defer reader.Close()
-
- br := bufio.NewReader(reader)
- var event, data string
-
- ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout)
- defer cancel()
-
- for {
- select {
- case <-ctx.Done():
- return
- default:
- line, err := br.ReadString('\n')
- if err != nil {
- if err == io.EOF {
- // Process any pending event before exit
- if event != "" && data != "" {
- c.handleSSEEvent(event, data)
- }
- break
- }
- select {
- case <-c.done:
- return
- default:
- fmt.Printf("SSE stream error: %v\n", err)
- return
- }
- }
-
- // Remove only newline markers
- line = strings.TrimRight(line, "\r\n")
- if line == "" {
- // Empty line means end of event
- if event != "" && data != "" {
- c.handleSSEEvent(event, data)
- event = ""
- data = ""
- }
- continue
- }
-
- if strings.HasPrefix(line, "event:") {
- event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
- } else if strings.HasPrefix(line, "data:") {
- data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
- }
- }
- }
-}
-
-// handleSSEEvent processes SSE events based on their type.
-// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
-func (c *SSEMCPClient) handleSSEEvent(event, data string) {
- switch event {
- case "endpoint":
- endpoint, err := c.baseURL.Parse(data)
- if err != nil {
- fmt.Printf("Error parsing endpoint URL: %v\n", err)
- return
- }
- if endpoint.Host != c.baseURL.Host {
- fmt.Printf("Endpoint origin does not match connection origin\n")
- return
- }
- c.endpoint = endpoint
- close(c.endpointChan)
-
- case "message":
- var baseMessage struct {
- JSONRPC string `json:"jsonrpc"`
- ID *int64 `json:"id,omitempty"`
- Method string `json:"method,omitempty"`
- Result json.RawMessage `json:"result,omitempty"`
- Error *struct {
- Code int `json:"code"`
- Message string `json:"message"`
- } `json:"error,omitempty"`
- }
-
- if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
- fmt.Printf("Error unmarshaling message: %v\n", err)
- return
- }
-
- // Handle notification
- if baseMessage.ID == nil {
- var notification mcp.JSONRPCNotification
- if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
- return
- }
- c.notifyMu.RLock()
- for _, handler := range c.notifications {
- handler(notification)
- }
- c.notifyMu.RUnlock()
- return
- }
-
- c.mu.RLock()
- ch, ok := c.responses[*baseMessage.ID]
- c.mu.RUnlock()
-
- if ok {
- if baseMessage.Error != nil {
- ch <- RPCResponse{
- Error: &baseMessage.Error.Message,
- }
- } else {
- ch <- RPCResponse{
- Response: &baseMessage.Result,
- }
- }
- c.mu.Lock()
- delete(c.responses, *baseMessage.ID)
- c.mu.Unlock()
- }
- }
-}
-
-// OnNotification registers a handler function to be called when notifications are received.
-// Multiple handlers can be registered and will be called in the order they were added.
-func (c *SSEMCPClient) OnNotification(
- handler func(notification mcp.JSONRPCNotification),
-) {
- c.notifyMu.Lock()
- defer c.notifyMu.Unlock()
- c.notifications = append(c.notifications, handler)
-}
-
-// sendRequest sends a JSON-RPC request to the server and waits for a response.
-// Returns the raw JSON response message or an error if the request fails.
-func (c *SSEMCPClient) sendRequest(
- ctx context.Context,
- method string,
- params interface{},
-) (*json.RawMessage, error) {
- if !c.initialized && method != "initialize" {
- return nil, fmt.Errorf("client not initialized")
- }
-
- if c.endpoint == nil {
- return nil, fmt.Errorf("endpoint not received")
- }
-
- id := c.requestID.Add(1)
-
- request := mcp.JSONRPCRequest{
- JSONRPC: mcp.JSONRPC_VERSION,
- ID: id,
- Request: mcp.Request{
- Method: method,
- },
- Params: params,
- }
-
- requestBytes, err := json.Marshal(request)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal request: %w", err)
- }
-
- responseChan := make(chan RPCResponse, 1)
- c.mu.Lock()
- c.responses[id] = responseChan
- c.mu.Unlock()
-
- req, err := http.NewRequestWithContext(
- ctx,
- "POST",
- c.endpoint.String(),
- bytes.NewReader(requestBytes),
- )
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
-
- req.Header.Set("Content-Type", "application/json")
- // set custom http headers
- for k, v := range c.headers {
- req.Header.Set(k, v)
- }
-
- resp, err := c.httpClient.Do(req)
- if err != nil {
- return nil, fmt.Errorf("failed to send request: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK &&
- resp.StatusCode != http.StatusAccepted {
- body, _ := io.ReadAll(resp.Body)
- return nil, fmt.Errorf(
- "request failed with status %d: %s",
- resp.StatusCode,
- body,
- )
- }
-
- select {
- case <-ctx.Done():
- c.mu.Lock()
- delete(c.responses, id)
- c.mu.Unlock()
- return nil, ctx.Err()
- case response := <-responseChan:
- if response.Error != nil {
- return nil, errors.New(*response.Error)
- }
- return response.Response, nil
- }
-}
-
-func (c *SSEMCPClient) Initialize(
- ctx context.Context,
- request mcp.InitializeRequest,
-) (*mcp.InitializeResult, error) {
- // Ensure we send a params object with all required fields
- params := struct {
- ProtocolVersion string `json:"protocolVersion"`
- ClientInfo mcp.Implementation `json:"clientInfo"`
- Capabilities mcp.ClientCapabilities `json:"capabilities"`
- }{
- ProtocolVersion: request.Params.ProtocolVersion,
- ClientInfo: request.Params.ClientInfo,
- Capabilities: request.Params.Capabilities, // Will be empty struct if not set
- }
-
- response, err := c.sendRequest(ctx, "initialize", params)
- if err != nil {
- return nil, err
- }
+func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) {
- var result mcp.InitializeResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- // Store capabilities
- c.capabilities = result.Capabilities
-
- // Send initialized notification
- notification := mcp.JSONRPCNotification{
- JSONRPC: mcp.JSONRPC_VERSION,
- Notification: mcp.Notification{
- Method: "notifications/initialized",
- },
- }
-
- notificationBytes, err := json.Marshal(notification)
- if err != nil {
- return nil, fmt.Errorf(
- "failed to marshal initialized notification: %w",
- err,
- )
- }
-
- req, err := http.NewRequestWithContext(
- ctx,
- "POST",
- c.endpoint.String(),
- bytes.NewReader(notificationBytes),
- )
- if err != nil {
- return nil, fmt.Errorf("failed to create notification request: %w", err)
- }
-
- req.Header.Set("Content-Type", "application/json")
-
- resp, err := c.httpClient.Do(req)
+ sseTransport, err := transport.NewSSE(baseURL, options...)
if err != nil {
- return nil, fmt.Errorf(
- "failed to send initialized notification: %w",
- err,
- )
+ return nil, fmt.Errorf("failed to create SSE transport: %w", err)
}
- resp.Body.Close()
- c.initialized = true
- return &result, nil
+ return NewClient(sseTransport), nil
}
-func (c *SSEMCPClient) Ping(ctx context.Context) error {
- _, err := c.sendRequest(ctx, "ping", nil)
- return err
-}
-
-func (c *SSEMCPClient) ListResources(
- ctx context.Context,
- request mcp.ListResourcesRequest,
-) (*mcp.ListResourcesResult, error) {
- response, err := c.sendRequest(ctx, "resources/list", request.Params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListResourcesResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *SSEMCPClient) ListResourceTemplates(
- ctx context.Context,
- request mcp.ListResourceTemplatesRequest,
-) (*mcp.ListResourceTemplatesResult, error) {
- response, err := c.sendRequest(
- ctx,
- "resources/templates/list",
- request.Params,
- )
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListResourceTemplatesResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *SSEMCPClient) ReadResource(
- ctx context.Context,
- request mcp.ReadResourceRequest,
-) (*mcp.ReadResourceResult, error) {
- response, err := c.sendRequest(ctx, "resources/read", request.Params)
- if err != nil {
- return nil, err
- }
-
- return mcp.ParseReadResourceResult(response)
-}
-
-func (c *SSEMCPClient) Subscribe(
- ctx context.Context,
- request mcp.SubscribeRequest,
-) error {
- _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
- return err
-}
-
-func (c *SSEMCPClient) Unsubscribe(
- ctx context.Context,
- request mcp.UnsubscribeRequest,
-) error {
- _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
- return err
-}
-
-func (c *SSEMCPClient) ListPrompts(
- ctx context.Context,
- request mcp.ListPromptsRequest,
-) (*mcp.ListPromptsResult, error) {
- response, err := c.sendRequest(ctx, "prompts/list", request.Params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListPromptsResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *SSEMCPClient) GetPrompt(
- ctx context.Context,
- request mcp.GetPromptRequest,
-) (*mcp.GetPromptResult, error) {
- response, err := c.sendRequest(ctx, "prompts/get", request.Params)
- if err != nil {
- return nil, err
- }
-
- return mcp.ParseGetPromptResult(response)
-}
-
-func (c *SSEMCPClient) ListTools(
- ctx context.Context,
- request mcp.ListToolsRequest,
-) (*mcp.ListToolsResult, error) {
- response, err := c.sendRequest(ctx, "tools/list", request.Params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListToolsResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *SSEMCPClient) CallTool(
- ctx context.Context,
- request mcp.CallToolRequest,
-) (*mcp.CallToolResult, error) {
- response, err := c.sendRequest(ctx, "tools/call", request.Params)
- if err != nil {
- return nil, err
- }
-
- return mcp.ParseCallToolResult(response)
-}
-
-func (c *SSEMCPClient) SetLevel(
- ctx context.Context,
- request mcp.SetLevelRequest,
-) error {
- _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
- return err
-}
-
-func (c *SSEMCPClient) Complete(
- ctx context.Context,
- request mcp.CompleteRequest,
-) (*mcp.CompleteResult, error) {
- response, err := c.sendRequest(ctx, "completion/complete", request.Params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.CompleteResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-// Helper methods
-
// GetEndpoint returns the current endpoint URL for the SSE connection.
-func (c *SSEMCPClient) GetEndpoint() *url.URL {
- return c.endpoint
-}
-
-// Close shuts down the SSE client connection and cleans up any pending responses.
-// Returns an error if the shutdown process fails.
-func (c *SSEMCPClient) Close() error {
- select {
- case <-c.done:
- return nil // Already closed
- default:
- close(c.done)
- }
-
- // Clean up any pending responses
- c.mu.Lock()
- for _, ch := range c.responses {
- close(ch)
- }
- c.responses = make(map[int64]chan RPCResponse)
- c.mu.Unlock()
-
- return nil
+//
+// Note: This method only works with SSE transport, or it will panic.
+func GetEndpoint(c *Client) *url.URL {
+ t := c.GetTransport()
+ sse := t.(*transport.SSE)
+ return sse.GetEndpoint()
}
@@ -1,441 +1,43 @@
package client
import (
- "bufio"
"context"
- "encoding/json"
- "errors"
"fmt"
"io"
- "os"
- "os/exec"
- "sync"
- "sync/atomic"
- "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/client/transport"
)
-// StdioMCPClient implements the MCPClient interface using stdio communication.
-// It launches a subprocess and communicates with it via standard input/output streams
-// using JSON-RPC messages. The client handles message routing between requests and
-// responses, and supports asynchronous notifications.
-type StdioMCPClient struct {
- cmd *exec.Cmd
- stdin io.WriteCloser
- stdout *bufio.Reader
- requestID atomic.Int64
- responses map[int64]chan RPCResponse
- mu sync.RWMutex
- done chan struct{}
- initialized bool
- notifications []func(mcp.JSONRPCNotification)
- notifyMu sync.RWMutex
- capabilities mcp.ServerCapabilities
-}
-
// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess.
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
// Returns an error if the subprocess cannot be started or the pipes cannot be created.
+//
+// NOTICE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually.
+// This is for backward compatibility.
func NewStdioMCPClient(
command string,
env []string,
args ...string,
-) (*StdioMCPClient, error) {
- cmd := exec.Command(command, args...)
-
- mergedEnv := os.Environ()
- mergedEnv = append(mergedEnv, env...)
-
- cmd.Env = mergedEnv
-
- stdin, err := cmd.StdinPipe()
- if err != nil {
- return nil, fmt.Errorf("failed to create stdin pipe: %w", err)
- }
+) (*Client, error) {
- stdout, err := cmd.StdoutPipe()
+ stdioTransport := transport.NewStdio(command, env, args...)
+ err := stdioTransport.Start(context.Background())
if err != nil {
- return nil, fmt.Errorf("failed to create stdout pipe: %w", err)
- }
-
- client := &StdioMCPClient{
- cmd: cmd,
- stdin: stdin,
- stdout: bufio.NewReader(stdout),
- responses: make(map[int64]chan RPCResponse),
- done: make(chan struct{}),
- }
-
- if err := cmd.Start(); err != nil {
- return nil, fmt.Errorf("failed to start command: %w", err)
- }
-
- // Start reading responses in a goroutine and wait for it to be ready
- ready := make(chan struct{})
- go func() {
- close(ready)
- client.readResponses()
- }()
- <-ready
-
- return client, nil
-}
-
-// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
-// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
-func (c *StdioMCPClient) Close() error {
- close(c.done)
- if err := c.stdin.Close(); err != nil {
- return fmt.Errorf("failed to close stdin: %w", err)
+ return nil, fmt.Errorf("failed to start stdio transport: %w", err)
}
- return c.cmd.Wait()
-}
-
-// OnNotification registers a handler function to be called when notifications are received.
-// Multiple handlers can be registered and will be called in the order they were added.
-func (c *StdioMCPClient) OnNotification(
- handler func(notification mcp.JSONRPCNotification),
-) {
- c.notifyMu.Lock()
- defer c.notifyMu.Unlock()
- c.notifications = append(c.notifications, handler)
-}
-
-// readResponses continuously reads and processes responses from the server's stdout.
-// It handles both responses to requests and notifications, routing them appropriately.
-// Runs until the done channel is closed or an error occurs reading from stdout.
-func (c *StdioMCPClient) readResponses() {
- for {
- select {
- case <-c.done:
- return
- default:
- line, err := c.stdout.ReadString('\n')
- if err != nil {
- if err != io.EOF {
- fmt.Printf("Error reading response: %v\n", err)
- }
- return
- }
-
- var baseMessage struct {
- JSONRPC string `json:"jsonrpc"`
- ID *int64 `json:"id,omitempty"`
- Method string `json:"method,omitempty"`
- Result json.RawMessage `json:"result,omitempty"`
- Error *struct {
- Code int `json:"code"`
- Message string `json:"message"`
- } `json:"error,omitempty"`
- }
- if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
- continue
- }
-
- // Handle notification
- if baseMessage.ID == nil {
- var notification mcp.JSONRPCNotification
- if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
- continue
- }
- c.notifyMu.RLock()
- for _, handler := range c.notifications {
- handler(notification)
- }
- c.notifyMu.RUnlock()
- continue
- }
-
- c.mu.RLock()
- ch, ok := c.responses[*baseMessage.ID]
- c.mu.RUnlock()
-
- if ok {
- if baseMessage.Error != nil {
- ch <- RPCResponse{
- Error: &baseMessage.Error.Message,
- }
- } else {
- ch <- RPCResponse{
- Response: &baseMessage.Result,
- }
- }
- c.mu.Lock()
- delete(c.responses, *baseMessage.ID)
- c.mu.Unlock()
- }
- }
- }
+ return NewClient(stdioTransport), nil
}
-// sendRequest sends a JSON-RPC request to the server and waits for a response.
-// It creates a unique request ID, sends the request over stdin, and waits for
-// the corresponding response or context cancellation.
-// Returns the raw JSON response message or an error if the request fails.
-func (c *StdioMCPClient) sendRequest(
- ctx context.Context,
- method string,
- params interface{},
-) (*json.RawMessage, error) {
- if !c.initialized && method != "initialize" {
- return nil, fmt.Errorf("client not initialized")
- }
-
- id := c.requestID.Add(1)
-
- // Create the complete request structure
- request := mcp.JSONRPCRequest{
- JSONRPC: mcp.JSONRPC_VERSION,
- ID: id,
- Request: mcp.Request{
- Method: method,
- },
- Params: params,
- }
-
- responseChan := make(chan RPCResponse, 1)
- c.mu.Lock()
- c.responses[id] = responseChan
- c.mu.Unlock()
-
- requestBytes, err := json.Marshal(request)
- if err != nil {
- return nil, fmt.Errorf("failed to marshal request: %w", err)
- }
- requestBytes = append(requestBytes, '\n')
-
- if _, err := c.stdin.Write(requestBytes); err != nil {
- return nil, fmt.Errorf("failed to write request: %w", err)
- }
-
- select {
- case <-ctx.Done():
- c.mu.Lock()
- delete(c.responses, id)
- c.mu.Unlock()
- return nil, ctx.Err()
- case response := <-responseChan:
- if response.Error != nil {
- return nil, errors.New(*response.Error)
- }
- return response.Response, nil
- }
-}
-
-func (c *StdioMCPClient) Ping(ctx context.Context) error {
- _, err := c.sendRequest(ctx, "ping", nil)
- return err
-}
-
-func (c *StdioMCPClient) Initialize(
- ctx context.Context,
- request mcp.InitializeRequest,
-) (*mcp.InitializeResult, error) {
- // This structure ensures Capabilities is always included in JSON
- params := struct {
- ProtocolVersion string `json:"protocolVersion"`
- ClientInfo mcp.Implementation `json:"clientInfo"`
- Capabilities mcp.ClientCapabilities `json:"capabilities"`
- }{
- ProtocolVersion: request.Params.ProtocolVersion,
- ClientInfo: request.Params.ClientInfo,
- Capabilities: request.Params.Capabilities, // Will be empty struct if not set
- }
-
- response, err := c.sendRequest(ctx, "initialize", params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.InitializeResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- // Store capabilities
- c.capabilities = result.Capabilities
-
- // Send initialized notification
- notification := mcp.JSONRPCNotification{
- JSONRPC: mcp.JSONRPC_VERSION,
- Notification: mcp.Notification{
- Method: "notifications/initialized",
- },
- }
-
- notificationBytes, err := json.Marshal(notification)
- if err != nil {
- return nil, fmt.Errorf(
- "failed to marshal initialized notification: %w",
- err,
- )
- }
- notificationBytes = append(notificationBytes, '\n')
-
- if _, err := c.stdin.Write(notificationBytes); err != nil {
- return nil, fmt.Errorf(
- "failed to send initialized notification: %w",
- err,
- )
- }
-
- c.initialized = true
- return &result, nil
-}
-
-func (c *StdioMCPClient) ListResources(
- ctx context.Context,
- request mcp.ListResourcesRequest,
-) (*mcp.
- ListResourcesResult, error) {
- response, err := c.sendRequest(
- ctx,
- "resources/list",
- request.Params,
- )
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListResourcesResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *StdioMCPClient) ListResourceTemplates(
- ctx context.Context,
- request mcp.ListResourceTemplatesRequest,
-) (*mcp.
- ListResourceTemplatesResult, error) {
- response, err := c.sendRequest(
- ctx,
- "resources/templates/list",
- request.Params,
- )
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListResourceTemplatesResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *StdioMCPClient) ReadResource(
- ctx context.Context,
- request mcp.ReadResourceRequest,
-) (*mcp.ReadResourceResult,
- error) {
- response, err := c.sendRequest(ctx, "resources/read", request.Params)
- if err != nil {
- return nil, err
- }
-
- return mcp.ParseReadResourceResult(response)
-}
-
-func (c *StdioMCPClient) Subscribe(
- ctx context.Context,
- request mcp.SubscribeRequest,
-) error {
- _, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
- return err
-}
-
-func (c *StdioMCPClient) Unsubscribe(
- ctx context.Context,
- request mcp.UnsubscribeRequest,
-) error {
- _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
- return err
-}
-
-func (c *StdioMCPClient) ListPrompts(
- ctx context.Context,
- request mcp.ListPromptsRequest,
-) (*mcp.ListPromptsResult, error) {
- response, err := c.sendRequest(ctx, "prompts/list", request.Params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListPromptsResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *StdioMCPClient) GetPrompt(
- ctx context.Context,
- request mcp.GetPromptRequest,
-) (*mcp.GetPromptResult, error) {
- response, err := c.sendRequest(ctx, "prompts/get", request.Params)
- if err != nil {
- return nil, err
- }
-
- return mcp.ParseGetPromptResult(response)
-}
-
-func (c *StdioMCPClient) ListTools(
- ctx context.Context,
- request mcp.ListToolsRequest,
-) (*mcp.ListToolsResult, error) {
- response, err := c.sendRequest(ctx, "tools/list", request.Params)
- if err != nil {
- return nil, err
- }
-
- var result mcp.ListToolsResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
- }
-
- return &result, nil
-}
-
-func (c *StdioMCPClient) CallTool(
- ctx context.Context,
- request mcp.CallToolRequest,
-) (*mcp.CallToolResult, error) {
- response, err := c.sendRequest(ctx, "tools/call", request.Params)
- if err != nil {
- return nil, err
- }
-
- return mcp.ParseCallToolResult(response)
-}
-
-func (c *StdioMCPClient) SetLevel(
- ctx context.Context,
- request mcp.SetLevelRequest,
-) error {
- _, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
- return err
-}
-
-func (c *StdioMCPClient) Complete(
- ctx context.Context,
- request mcp.CompleteRequest,
-) (*mcp.CompleteResult, error) {
- response, err := c.sendRequest(ctx, "completion/complete", request.Params)
- if err != nil {
- return nil, err
- }
+// GetStderr returns a reader for the stderr output of the subprocess.
+// This can be used to capture error messages or logs from the subprocess.
+func GetStderr(c *Client) (io.Reader, bool) {
+ t := c.GetTransport()
- var result mcp.CompleteResult
- if err := json.Unmarshal(*response, &result); err != nil {
- return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+ stdio, ok := t.(*transport.Stdio)
+ if !ok {
+ return nil, false
}
- return &result, nil
+ return stdio.Stderr(), true
}
@@ -0,0 +1,70 @@
+package transport
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "sync"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+)
+
+type InProcessTransport struct {
+ server *server.MCPServer
+
+ onNotification func(mcp.JSONRPCNotification)
+ notifyMu sync.RWMutex
+}
+
+func NewInProcessTransport(server *server.MCPServer) *InProcessTransport {
+ return &InProcessTransport{
+ server: server,
+ }
+}
+
+func (c *InProcessTransport) Start(ctx context.Context) error {
+ return nil
+}
+
+func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) {
+ requestBytes, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+ requestBytes = append(requestBytes, '\n')
+
+ respMessage := c.server.HandleMessage(ctx, requestBytes)
+ respByte, err := json.Marshal(respMessage)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal response message: %w", err)
+ }
+ rpcResp := JSONRPCResponse{}
+ err = json.Unmarshal(respByte, &rpcResp)
+ if err != nil {
+ return nil, fmt.Errorf("failed to unmarshal response message: %w", err)
+ }
+
+ return &rpcResp, nil
+}
+
+func (c *InProcessTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
+ notificationBytes, err := json.Marshal(notification)
+ if err != nil {
+ return fmt.Errorf("failed to marshal notification: %w", err)
+ }
+ notificationBytes = append(notificationBytes, '\n')
+ c.server.HandleMessage(ctx, notificationBytes)
+
+ return nil
+}
+
+func (c *InProcessTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
+ c.notifyMu.Lock()
+ defer c.notifyMu.Unlock()
+ c.onNotification = handler
+}
+
+func (*InProcessTransport) Close() error {
+ return nil
+}
@@ -0,0 +1,50 @@
+package transport
+
+import (
+ "context"
+ "encoding/json"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// HTTPHeaderFunc is a function that extracts header entries from the given context
+// and returns them as key-value pairs. This is typically used to add context values
+// as HTTP headers in outgoing requests.
+type HTTPHeaderFunc func(context.Context) map[string]string
+
+// Interface for the transport layer.
+type Interface interface {
+ // Start the connection. Start should only be called once.
+ Start(ctx context.Context) error
+
+ // SendRequest sends a json RPC request and returns the response synchronously.
+ SendRequest(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error)
+
+ // SendNotification sends a json RPC Notification to the server.
+ SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error
+
+ // SetNotificationHandler sets the handler for notifications.
+ // Any notification before the handler is set will be discarded.
+ SetNotificationHandler(handler func(notification mcp.JSONRPCNotification))
+
+ // Close the connection.
+ Close() error
+}
+
+type JSONRPCRequest struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID mcp.RequestId `json:"id"`
+ Method string `json:"method"`
+ Params any `json:"params,omitempty"`
+}
+
+type JSONRPCResponse struct {
+ JSONRPC string `json:"jsonrpc"`
+ ID mcp.RequestId `json:"id"`
+ Result json.RawMessage `json:"result"`
+ Error *struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data json.RawMessage `json:"data"`
+ } `json:"error"`
+}
@@ -0,0 +1,650 @@
+package transport
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "time"
+)
+
+// OAuthConfig holds the OAuth configuration for the client
+type OAuthConfig struct {
+ // ClientID is the OAuth client ID
+ ClientID string
+ // ClientSecret is the OAuth client secret (for confidential clients)
+ ClientSecret string
+ // RedirectURI is the redirect URI for the OAuth flow
+ RedirectURI string
+ // Scopes is the list of OAuth scopes to request
+ Scopes []string
+ // TokenStore is the storage for OAuth tokens
+ TokenStore TokenStore
+ // AuthServerMetadataURL is the URL to the OAuth server metadata
+ // If empty, the client will attempt to discover it from the base URL
+ AuthServerMetadataURL string
+ // PKCEEnabled enables PKCE for the OAuth flow (recommended for public clients)
+ PKCEEnabled bool
+}
+
+// TokenStore is an interface for storing and retrieving OAuth tokens
+type TokenStore interface {
+ // GetToken returns the current token
+ GetToken() (*Token, error)
+ // SaveToken saves a token
+ SaveToken(token *Token) error
+}
+
+// Token represents an OAuth token
+type Token struct {
+ // AccessToken is the OAuth access token
+ AccessToken string `json:"access_token"`
+ // TokenType is the type of token (usually "Bearer")
+ TokenType string `json:"token_type"`
+ // RefreshToken is the OAuth refresh token
+ RefreshToken string `json:"refresh_token,omitempty"`
+ // ExpiresIn is the number of seconds until the token expires
+ ExpiresIn int64 `json:"expires_in,omitempty"`
+ // Scope is the scope of the token
+ Scope string `json:"scope,omitempty"`
+ // ExpiresAt is the time when the token expires
+ ExpiresAt time.Time `json:"expires_at,omitempty"`
+}
+
+// IsExpired returns true if the token is expired
+func (t *Token) IsExpired() bool {
+ if t.ExpiresAt.IsZero() {
+ return false
+ }
+ return time.Now().After(t.ExpiresAt)
+}
+
+// MemoryTokenStore is a simple in-memory token store
+type MemoryTokenStore struct {
+ token *Token
+ mu sync.RWMutex
+}
+
+// NewMemoryTokenStore creates a new in-memory token store
+func NewMemoryTokenStore() *MemoryTokenStore {
+ return &MemoryTokenStore{}
+}
+
+// GetToken returns the current token
+func (s *MemoryTokenStore) GetToken() (*Token, error) {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ if s.token == nil {
+ return nil, errors.New("no token available")
+ }
+ return s.token, nil
+}
+
+// SaveToken saves a token
+func (s *MemoryTokenStore) SaveToken(token *Token) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.token = token
+ return nil
+}
+
+// AuthServerMetadata represents the OAuth 2.0 Authorization Server Metadata
+type AuthServerMetadata struct {
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ TokenEndpoint string `json:"token_endpoint"`
+ RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
+ JwksURI string `json:"jwks_uri,omitempty"`
+ ScopesSupported []string `json:"scopes_supported,omitempty"`
+ ResponseTypesSupported []string `json:"response_types_supported"`
+ GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
+ TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
+}
+
+// OAuthHandler handles OAuth authentication for HTTP requests
+type OAuthHandler struct {
+ config OAuthConfig
+ httpClient *http.Client
+ serverMetadata *AuthServerMetadata
+ metadataFetchErr error
+ metadataOnce sync.Once
+ baseURL string
+ expectedState string // Expected state value for CSRF protection
+}
+
+// NewOAuthHandler creates a new OAuth handler
+func NewOAuthHandler(config OAuthConfig) *OAuthHandler {
+ if config.TokenStore == nil {
+ config.TokenStore = NewMemoryTokenStore()
+ }
+
+ return &OAuthHandler{
+ config: config,
+ httpClient: &http.Client{Timeout: 30 * time.Second},
+ }
+}
+
+// GetAuthorizationHeader returns the Authorization header value for a request
+func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, error) {
+ token, err := h.getValidToken(ctx)
+ if err != nil {
+ return "", err
+ }
+
+ // Some auth implementations are strict about token type
+ tokenType := token.TokenType
+ if tokenType == "bearer" {
+ tokenType = "Bearer"
+ }
+
+ return fmt.Sprintf("%s %s", tokenType, token.AccessToken), nil
+}
+
+// getValidToken returns a valid token, refreshing if necessary
+func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) {
+ token, err := h.config.TokenStore.GetToken()
+ if err == nil && !token.IsExpired() && token.AccessToken != "" {
+ return token, nil
+ }
+
+ // If we have a refresh token, try to use it
+ if err == nil && token.RefreshToken != "" {
+ newToken, err := h.refreshToken(ctx, token.RefreshToken)
+ if err == nil {
+ return newToken, nil
+ }
+ // If refresh fails, continue to authorization flow
+ }
+
+ // We need to get a new token through the authorization flow
+ return nil, ErrOAuthAuthorizationRequired
+}
+
+// refreshToken refreshes an OAuth token
+func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*Token, error) {
+ metadata, err := h.getServerMetadata(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get server metadata: %w", err)
+ }
+
+ data := url.Values{}
+ data.Set("grant_type", "refresh_token")
+ data.Set("refresh_token", refreshToken)
+ data.Set("client_id", h.config.ClientID)
+ if h.config.ClientSecret != "" {
+ data.Set("client_secret", h.config.ClientSecret)
+ }
+
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodPost,
+ metadata.TokenEndpoint,
+ strings.NewReader(data.Encode()),
+ )
+ if err != nil {
+ return nil, fmt.Errorf("failed to create refresh token request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := h.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send refresh token request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return nil, extractOAuthError(body, resp.StatusCode, "refresh token request failed")
+ }
+
+ var tokenResp Token
+ if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
+ return nil, fmt.Errorf("failed to decode token response: %w", err)
+ }
+
+ // Set expiration time
+ if tokenResp.ExpiresIn > 0 {
+ tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+ }
+
+ // If no new refresh token is provided, keep the old one
+ oldToken, _ := h.config.TokenStore.GetToken()
+ if tokenResp.RefreshToken == "" && oldToken != nil {
+ tokenResp.RefreshToken = oldToken.RefreshToken
+ }
+
+ // Save the token
+ if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
+ return nil, fmt.Errorf("failed to save token: %w", err)
+ }
+
+ return &tokenResp, nil
+}
+
+// RefreshToken is a public wrapper for refreshToken
+func (h *OAuthHandler) RefreshToken(ctx context.Context, refreshToken string) (*Token, error) {
+ return h.refreshToken(ctx, refreshToken)
+}
+
+// GetClientID returns the client ID
+func (h *OAuthHandler) GetClientID() string {
+ return h.config.ClientID
+}
+
+// extractOAuthError attempts to parse an OAuth error response from the response body
+func extractOAuthError(body []byte, statusCode int, context string) error {
+ // Try to parse the error as an OAuth error response
+ var oauthErr OAuthError
+ if err := json.Unmarshal(body, &oauthErr); err == nil && oauthErr.ErrorCode != "" {
+ return fmt.Errorf("%s: %w", context, oauthErr)
+ }
+
+ // If not a valid OAuth error, return the raw response
+ return fmt.Errorf("%s with status %d: %s", context, statusCode, body)
+}
+
+// GetClientSecret returns the client secret
+func (h *OAuthHandler) GetClientSecret() string {
+ return h.config.ClientSecret
+}
+
+// SetBaseURL sets the base URL for the API server
+func (h *OAuthHandler) SetBaseURL(baseURL string) {
+ h.baseURL = baseURL
+}
+
+// GetExpectedState returns the expected state value (for testing purposes)
+func (h *OAuthHandler) GetExpectedState() string {
+ return h.expectedState
+}
+
+// OAuthError represents a standard OAuth 2.0 error response
+type OAuthError struct {
+ ErrorCode string `json:"error"`
+ ErrorDescription string `json:"error_description,omitempty"`
+ ErrorURI string `json:"error_uri,omitempty"`
+}
+
+// Error implements the error interface
+func (e OAuthError) Error() string {
+ if e.ErrorDescription != "" {
+ return fmt.Sprintf("OAuth error: %s - %s", e.ErrorCode, e.ErrorDescription)
+ }
+ return fmt.Sprintf("OAuth error: %s", e.ErrorCode)
+}
+
+// OAuthProtectedResource represents the response from /.well-known/oauth-protected-resource
+type OAuthProtectedResource struct {
+ AuthorizationServers []string `json:"authorization_servers"`
+ Resource string `json:"resource"`
+ ResourceName string `json:"resource_name,omitempty"`
+}
+
+// getServerMetadata fetches the OAuth server metadata
+func (h *OAuthHandler) getServerMetadata(ctx context.Context) (*AuthServerMetadata, error) {
+ h.metadataOnce.Do(func() {
+ // If AuthServerMetadataURL is explicitly provided, use it directly
+ if h.config.AuthServerMetadataURL != "" {
+ h.fetchMetadataFromURL(ctx, h.config.AuthServerMetadataURL)
+ return
+ }
+
+ // Try to discover the authorization server via OAuth Protected Resource
+ // as per RFC 9728 (https://datatracker.ietf.org/doc/html/rfc9728)
+ baseURL, err := h.extractBaseURL()
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to extract base URL: %w", err)
+ return
+ }
+
+ // Try to fetch the OAuth Protected Resource metadata
+ protectedResourceURL := baseURL + "/.well-known/oauth-protected-resource"
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, protectedResourceURL, nil)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to create protected resource request: %w", err)
+ return
+ }
+
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("MCP-Protocol-Version", "2025-03-26")
+
+ resp, err := h.httpClient.Do(req)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to send protected resource request: %w", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ // If we can't get the protected resource metadata, fall back to default endpoints
+ if resp.StatusCode != http.StatusOK {
+ metadata, err := h.getDefaultEndpoints(baseURL)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
+ return
+ }
+ h.serverMetadata = metadata
+ return
+ }
+
+ // Parse the protected resource metadata
+ var protectedResource OAuthProtectedResource
+ if err := json.NewDecoder(resp.Body).Decode(&protectedResource); err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to decode protected resource response: %w", err)
+ return
+ }
+
+ // If no authorization servers are specified, fall back to default endpoints
+ if len(protectedResource.AuthorizationServers) == 0 {
+ metadata, err := h.getDefaultEndpoints(baseURL)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
+ return
+ }
+ h.serverMetadata = metadata
+ return
+ }
+
+ // Use the first authorization server
+ authServerURL := protectedResource.AuthorizationServers[0]
+
+ // Try OpenID Connect discovery first
+ h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/openid-configuration")
+ if h.serverMetadata != nil {
+ return
+ }
+
+ // If OpenID Connect discovery fails, try OAuth Authorization Server Metadata
+ h.fetchMetadataFromURL(ctx, authServerURL+"/.well-known/oauth-authorization-server")
+ if h.serverMetadata != nil {
+ return
+ }
+
+ // If both discovery methods fail, use default endpoints based on the authorization server URL
+ metadata, err := h.getDefaultEndpoints(authServerURL)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to get default endpoints: %w", err)
+ return
+ }
+ h.serverMetadata = metadata
+ })
+
+ if h.metadataFetchErr != nil {
+ return nil, h.metadataFetchErr
+ }
+
+ return h.serverMetadata, nil
+}
+
+// fetchMetadataFromURL fetches and parses OAuth server metadata from a URL
+func (h *OAuthHandler) fetchMetadataFromURL(ctx context.Context, metadataURL string) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, metadataURL, nil)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to create metadata request: %w", err)
+ return
+ }
+
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("MCP-Protocol-Version", "2025-03-26")
+
+ resp, err := h.httpClient.Do(req)
+ if err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to send metadata request: %w", err)
+ return
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ // If metadata discovery fails, don't set any metadata
+ return
+ }
+
+ var metadata AuthServerMetadata
+ if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
+ h.metadataFetchErr = fmt.Errorf("failed to decode metadata response: %w", err)
+ return
+ }
+
+ h.serverMetadata = &metadata
+}
+
+// extractBaseURL extracts the base URL from the first request
+func (h *OAuthHandler) extractBaseURL() (string, error) {
+ // If we have a base URL from a previous request, use it
+ if h.baseURL != "" {
+ return h.baseURL, nil
+ }
+
+ // Otherwise, we need to infer it from the redirect URI
+ if h.config.RedirectURI == "" {
+ return "", fmt.Errorf("no base URL available and no redirect URI provided")
+ }
+
+ // Parse the redirect URI to extract the authority
+ parsedURL, err := url.Parse(h.config.RedirectURI)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse redirect URI: %w", err)
+ }
+
+ // Use the scheme and host from the redirect URI
+ baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
+ return baseURL, nil
+}
+
+// GetServerMetadata is a public wrapper for getServerMetadata
+func (h *OAuthHandler) GetServerMetadata(ctx context.Context) (*AuthServerMetadata, error) {
+ return h.getServerMetadata(ctx)
+}
+
+// getDefaultEndpoints returns default OAuth endpoints based on the base URL
+func (h *OAuthHandler) getDefaultEndpoints(baseURL string) (*AuthServerMetadata, error) {
+ // Parse the base URL to extract the authority
+ parsedURL, err := url.Parse(baseURL)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse base URL: %w", err)
+ }
+
+ // Discard any path component to get the authorization base URL
+ parsedURL.Path = ""
+ authBaseURL := parsedURL.String()
+
+ // Validate that the URL has a scheme and host
+ if parsedURL.Scheme == "" || parsedURL.Host == "" {
+ return nil, fmt.Errorf("invalid base URL: missing scheme or host in %q", baseURL)
+ }
+
+ return &AuthServerMetadata{
+ Issuer: authBaseURL,
+ AuthorizationEndpoint: authBaseURL + "/authorize",
+ TokenEndpoint: authBaseURL + "/token",
+ RegistrationEndpoint: authBaseURL + "/register",
+ }, nil
+}
+
+// RegisterClient performs dynamic client registration
+func (h *OAuthHandler) RegisterClient(ctx context.Context, clientName string) error {
+ metadata, err := h.getServerMetadata(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get server metadata: %w", err)
+ }
+
+ if metadata.RegistrationEndpoint == "" {
+ return errors.New("server does not support dynamic client registration")
+ }
+
+ // Prepare registration request
+ regRequest := map[string]any{
+ "client_name": clientName,
+ "redirect_uris": []string{h.config.RedirectURI},
+ "token_endpoint_auth_method": "none", // For public clients
+ "grant_types": []string{"authorization_code", "refresh_token"},
+ "response_types": []string{"code"},
+ "scope": strings.Join(h.config.Scopes, " "),
+ }
+
+ // Add client_secret if this is a confidential client
+ if h.config.ClientSecret != "" {
+ regRequest["token_endpoint_auth_method"] = "client_secret_basic"
+ }
+
+ reqBody, err := json.Marshal(regRequest)
+ if err != nil {
+ return fmt.Errorf("failed to marshal registration request: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodPost,
+ metadata.RegistrationEndpoint,
+ bytes.NewReader(reqBody),
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create registration request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := h.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send registration request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return extractOAuthError(body, resp.StatusCode, "registration request failed")
+ }
+
+ var regResponse struct {
+ ClientID string `json:"client_id"`
+ ClientSecret string `json:"client_secret,omitempty"`
+ }
+
+ if err := json.NewDecoder(resp.Body).Decode(®Response); err != nil {
+ return fmt.Errorf("failed to decode registration response: %w", err)
+ }
+
+ // Update the client configuration
+ h.config.ClientID = regResponse.ClientID
+ if regResponse.ClientSecret != "" {
+ h.config.ClientSecret = regResponse.ClientSecret
+ }
+
+ return nil
+}
+
+// ErrInvalidState is returned when the state parameter doesn't match the expected value
+var ErrInvalidState = errors.New("invalid state parameter, possible CSRF attack")
+
+// ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token
+func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error {
+ // Validate the state parameter to prevent CSRF attacks
+ if h.expectedState == "" {
+ return errors.New("no expected state found, authorization flow may not have been initiated properly")
+ }
+
+ if state != h.expectedState {
+ return ErrInvalidState
+ }
+
+ // Clear the expected state after validation
+ defer func() {
+ h.expectedState = ""
+ }()
+
+ metadata, err := h.getServerMetadata(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to get server metadata: %w", err)
+ }
+
+ data := url.Values{}
+ data.Set("grant_type", "authorization_code")
+ data.Set("code", code)
+ data.Set("client_id", h.config.ClientID)
+ data.Set("redirect_uri", h.config.RedirectURI)
+
+ if h.config.ClientSecret != "" {
+ data.Set("client_secret", h.config.ClientSecret)
+ }
+
+ if h.config.PKCEEnabled && codeVerifier != "" {
+ data.Set("code_verifier", codeVerifier)
+ }
+
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodPost,
+ metadata.TokenEndpoint,
+ strings.NewReader(data.Encode()),
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create token request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
+ req.Header.Set("Accept", "application/json")
+
+ resp, err := h.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send token request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return extractOAuthError(body, resp.StatusCode, "token request failed")
+ }
+
+ var tokenResp Token
+ if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
+ return fmt.Errorf("failed to decode token response: %w", err)
+ }
+
+ // Set expiration time
+ if tokenResp.ExpiresIn > 0 {
+ tokenResp.ExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
+ }
+
+ // Save the token
+ if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
+ return fmt.Errorf("failed to save token: %w", err)
+ }
+
+ return nil
+}
+
+// GetAuthorizationURL returns the URL for the authorization endpoint
+func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChallenge string) (string, error) {
+ metadata, err := h.getServerMetadata(ctx)
+ if err != nil {
+ return "", fmt.Errorf("failed to get server metadata: %w", err)
+ }
+
+ // Store the state for later validation
+ h.expectedState = state
+
+ params := url.Values{}
+ params.Set("response_type", "code")
+ params.Set("client_id", h.config.ClientID)
+ params.Set("redirect_uri", h.config.RedirectURI)
+ params.Set("state", state)
+
+ if len(h.config.Scopes) > 0 {
+ params.Set("scope", strings.Join(h.config.Scopes, " "))
+ }
+
+ if h.config.PKCEEnabled && codeChallenge != "" {
+ params.Set("code_challenge", codeChallenge)
+ params.Set("code_challenge_method", "S256")
+ }
+
+ return metadata.AuthorizationEndpoint + "?" + params.Encode(), nil
+}
@@ -0,0 +1,68 @@
+package transport
+
+import (
+ "crypto/rand"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "net/url"
+)
+
+// GenerateRandomString generates a random string of the specified length
+func GenerateRandomString(length int) (string, error) {
+ bytes := make([]byte, length)
+ if _, err := rand.Read(bytes); err != nil {
+ return "", err
+ }
+ return base64.RawURLEncoding.EncodeToString(bytes)[:length], nil
+}
+
+// GenerateCodeVerifier generates a code verifier for PKCE
+func GenerateCodeVerifier() (string, error) {
+ // According to RFC 7636, the code verifier should be between 43 and 128 characters
+ return GenerateRandomString(64)
+}
+
+// GenerateCodeChallenge generates a code challenge from a code verifier
+func GenerateCodeChallenge(codeVerifier string) string {
+ // SHA256 hash the code verifier
+ hash := sha256.Sum256([]byte(codeVerifier))
+ // Base64url encode the hash
+ return base64.RawURLEncoding.EncodeToString(hash[:])
+}
+
+// GenerateState generates a state parameter for OAuth
+func GenerateState() (string, error) {
+ return GenerateRandomString(32)
+}
+
+// ValidateRedirectURI validates that a redirect URI is secure
+func ValidateRedirectURI(redirectURI string) error {
+ // According to the spec, redirect URIs must be either localhost URLs or HTTPS URLs
+ if redirectURI == "" {
+ return fmt.Errorf("redirect URI cannot be empty")
+ }
+
+ // Parse the URL
+ parsedURL, err := url.Parse(redirectURI)
+ if err != nil {
+ return fmt.Errorf("invalid redirect URI: %w", err)
+ }
+
+ // Check if it's a localhost URL
+ if parsedURL.Scheme == "http" {
+ hostname := parsedURL.Hostname()
+ // Check for various forms of localhost
+ if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "[::1]" {
+ return nil
+ }
+ return fmt.Errorf("HTTP redirect URI must use localhost or 127.0.0.1")
+ }
+
+ // Check if it's an HTTPS URL
+ if parsedURL.Scheme == "https" {
+ return nil
+ }
+
+ return fmt.Errorf("redirect URI must use either HTTP with localhost or HTTPS")
+}
@@ -0,0 +1,522 @@
+package transport
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
+// It maintains a persistent HTTP connection to receive server-pushed events
+// while sending requests over regular HTTP POST calls. The client handles
+// automatic reconnection and message routing between requests and responses.
+type SSE struct {
+ baseURL *url.URL
+ endpoint *url.URL
+ httpClient *http.Client
+ responses map[string]chan *JSONRPCResponse
+ mu sync.RWMutex
+ onNotification func(mcp.JSONRPCNotification)
+ notifyMu sync.RWMutex
+ endpointChan chan struct{}
+ headers map[string]string
+ headerFunc HTTPHeaderFunc
+
+ started atomic.Bool
+ closed atomic.Bool
+ cancelSSEStream context.CancelFunc
+
+ // OAuth support
+ oauthHandler *OAuthHandler
+}
+
+type ClientOption func(*SSE)
+
+func WithHeaders(headers map[string]string) ClientOption {
+ return func(sc *SSE) {
+ sc.headers = headers
+ }
+}
+
+func WithHeaderFunc(headerFunc HTTPHeaderFunc) ClientOption {
+ return func(sc *SSE) {
+ sc.headerFunc = headerFunc
+ }
+}
+
+func WithHTTPClient(httpClient *http.Client) ClientOption {
+ return func(sc *SSE) {
+ sc.httpClient = httpClient
+ }
+}
+
+func WithOAuth(config OAuthConfig) ClientOption {
+ return func(sc *SSE) {
+ sc.oauthHandler = NewOAuthHandler(config)
+ }
+}
+
+// NewSSE creates a new SSE-based MCP client with the given base URL.
+// Returns an error if the URL is invalid.
+func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
+ parsedURL, err := url.Parse(baseURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid URL: %w", err)
+ }
+
+ smc := &SSE{
+ baseURL: parsedURL,
+ httpClient: &http.Client{},
+ responses: make(map[string]chan *JSONRPCResponse),
+ endpointChan: make(chan struct{}),
+ headers: make(map[string]string),
+ }
+
+ for _, opt := range options {
+ opt(smc)
+ }
+
+ // If OAuth is configured, set the base URL for metadata discovery
+ if smc.oauthHandler != nil {
+ // Extract base URL from server URL for metadata discovery
+ baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
+ smc.oauthHandler.SetBaseURL(baseURL)
+ }
+
+ return smc, nil
+}
+
+// Start initiates the SSE connection to the server and waits for the endpoint information.
+// Returns an error if the connection fails or times out waiting for the endpoint.
+func (c *SSE) Start(ctx context.Context) error {
+
+ if c.started.Load() {
+ return fmt.Errorf("has already started")
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ c.cancelSSEStream = cancel
+
+ req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
+
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+
+ req.Header.Set("Accept", "text/event-stream")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Connection", "keep-alive")
+
+ // set custom http headers
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if err.Error() == "no valid token available, authorization required" {
+ return &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to connect to SSE stream: %w", err)
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ resp.Body.Close()
+ // Handle OAuth unauthorized error
+ if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
+ return &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+ }
+
+ go c.readSSE(resp.Body)
+
+ // Wait for the endpoint to be received
+ timeout := time.NewTimer(30 * time.Second)
+ defer timeout.Stop()
+ select {
+ case <-c.endpointChan:
+ // Endpoint received, proceed
+ case <-ctx.Done():
+ return fmt.Errorf("context cancelled while waiting for endpoint")
+ case <-timeout.C: // Add a timeout
+ cancel()
+ return fmt.Errorf("timeout waiting for endpoint")
+ }
+
+ c.started.Store(true)
+ return nil
+}
+
+// readSSE continuously reads the SSE stream and processes events.
+// It runs until the connection is closed or an error occurs.
+func (c *SSE) readSSE(reader io.ReadCloser) {
+ defer reader.Close()
+
+ br := bufio.NewReader(reader)
+ var event, data string
+
+ for {
+ // when close or start's ctx cancel, the reader will be closed
+ // and the for loop will break.
+ line, err := br.ReadString('\n')
+ if err != nil {
+ if err == io.EOF {
+ // Process any pending event before exit
+ if data != "" {
+ // If no event type is specified, use empty string (default event type)
+ if event == "" {
+ event = "message"
+ }
+ c.handleSSEEvent(event, data)
+ }
+ break
+ }
+ if !c.closed.Load() {
+ fmt.Printf("SSE stream error: %v\n", err)
+ }
+ return
+ }
+
+ // Remove only newline markers
+ line = strings.TrimRight(line, "\r\n")
+ if line == "" {
+ // Empty line means end of event
+ if data != "" {
+ // If no event type is specified, use empty string (default event type)
+ if event == "" {
+ event = "message"
+ }
+ c.handleSSEEvent(event, data)
+ event = ""
+ data = ""
+ }
+ continue
+ }
+
+ if strings.HasPrefix(line, "event:") {
+ event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
+ } else if strings.HasPrefix(line, "data:") {
+ data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+ }
+ }
+}
+
+// handleSSEEvent processes SSE events based on their type.
+// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication.
+func (c *SSE) handleSSEEvent(event, data string) {
+ switch event {
+ case "endpoint":
+ endpoint, err := c.baseURL.Parse(data)
+ if err != nil {
+ fmt.Printf("Error parsing endpoint URL: %v\n", err)
+ return
+ }
+ if endpoint.Host != c.baseURL.Host {
+ fmt.Printf("Endpoint origin does not match connection origin\n")
+ return
+ }
+ c.endpoint = endpoint
+ close(c.endpointChan)
+
+ case "message":
+ var baseMessage JSONRPCResponse
+ if err := json.Unmarshal([]byte(data), &baseMessage); err != nil {
+ fmt.Printf("Error unmarshaling message: %v\n", err)
+ return
+ }
+
+ // Handle notification
+ if baseMessage.ID.IsNil() {
+ var notification mcp.JSONRPCNotification
+ if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
+ return
+ }
+ c.notifyMu.RLock()
+ if c.onNotification != nil {
+ c.onNotification(notification)
+ }
+ c.notifyMu.RUnlock()
+ return
+ }
+
+ // Create string key for map lookup
+ idKey := baseMessage.ID.String()
+
+ c.mu.RLock()
+ ch, exists := c.responses[idKey]
+ c.mu.RUnlock()
+
+ if exists {
+ ch <- &baseMessage
+ c.mu.Lock()
+ delete(c.responses, idKey)
+ c.mu.Unlock()
+ }
+ }
+}
+
+func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) {
+ c.notifyMu.Lock()
+ defer c.notifyMu.Unlock()
+ c.onNotification = handler
+}
+
+// SendRequest sends a JSON-RPC request to the server and waits for a response.
+// Returns the raw JSON response message or an error if the request fails.
+func (c *SSE) SendRequest(
+ ctx context.Context,
+ request JSONRPCRequest,
+) (*JSONRPCResponse, error) {
+
+ if !c.started.Load() {
+ return nil, fmt.Errorf("transport not started yet")
+ }
+ if c.closed.Load() {
+ return nil, fmt.Errorf("transport has been closed")
+ }
+ if c.endpoint == nil {
+ return nil, fmt.Errorf("endpoint not received")
+ }
+
+ // Marshal request
+ requestBytes, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ // Create HTTP request
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.endpoint.String(), bytes.NewReader(requestBytes))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if err.Error() == "no valid token available, authorization required" {
+ return nil, &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return nil, fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ // Create string key for map lookup
+ idKey := request.ID.String()
+
+ // Register response channel
+ responseChan := make(chan *JSONRPCResponse, 1)
+ c.mu.Lock()
+ c.responses[idKey] = responseChan
+ c.mu.Unlock()
+ deleteResponseChan := func() {
+ c.mu.Lock()
+ delete(c.responses, idKey)
+ c.mu.Unlock()
+ }
+
+ // Send request
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ deleteResponseChan()
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+
+ // Drain any outstanding io
+ body, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ // Check if we got an error response
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+ deleteResponseChan()
+
+ // Handle OAuth unauthorized error
+ if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
+ return nil, &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+
+ return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
+ }
+
+ select {
+ case <-ctx.Done():
+ deleteResponseChan()
+ return nil, ctx.Err()
+ case response, ok := <-responseChan:
+ if ok {
+ return response, nil
+ }
+ return nil, fmt.Errorf("connection has been closed")
+ }
+}
+
+// Close shuts down the SSE client connection and cleans up any pending responses.
+// Returns an error if the shutdown process fails.
+func (c *SSE) Close() error {
+ if !c.closed.CompareAndSwap(false, true) {
+ return nil // Already closed
+ }
+
+ if c.cancelSSEStream != nil {
+ // It could stop the sse stream body, to quit the readSSE loop immediately
+ // Also, it could quit start() immediately if not receiving the endpoint
+ c.cancelSSEStream()
+ }
+
+ // Clean up any pending responses
+ c.mu.Lock()
+ for _, ch := range c.responses {
+ close(ch)
+ }
+ c.responses = make(map[string]chan *JSONRPCResponse)
+ c.mu.Unlock()
+
+ return nil
+}
+
+// SendNotification sends a JSON-RPC notification to the server without expecting a response.
+func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
+ if c.endpoint == nil {
+ return fmt.Errorf("endpoint not received")
+ }
+
+ notificationBytes, err := json.Marshal(notification)
+ if err != nil {
+ return fmt.Errorf("failed to marshal notification: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(
+ ctx,
+ "POST",
+ c.endpoint.String(),
+ bytes.NewReader(notificationBytes),
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create notification request: %w", err)
+ }
+
+ req.Header.Set("Content-Type", "application/json")
+ // Set custom HTTP headers
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if errors.Is(err, ErrOAuthAuthorizationRequired) {
+ return &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send notification: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+ // Handle OAuth unauthorized error
+ if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
+ return &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf(
+ "notification failed with status %d: %s",
+ resp.StatusCode,
+ body,
+ )
+ }
+
+ return nil
+}
+
+// GetEndpoint returns the current endpoint URL for the SSE connection.
+func (c *SSE) GetEndpoint() *url.URL {
+ return c.endpoint
+}
+
+// GetBaseURL returns the base URL set in the SSE constructor.
+func (c *SSE) GetBaseURL() *url.URL {
+ return c.baseURL
+}
+
+// GetOAuthHandler returns the OAuth handler if configured
+func (c *SSE) GetOAuthHandler() *OAuthHandler {
+ return c.oauthHandler
+}
+
+// IsOAuthEnabled returns true if OAuth is enabled
+func (c *SSE) IsOAuthEnabled() bool {
+ return c.oauthHandler != nil
+}
@@ -0,0 +1,288 @@
+package transport
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "sync"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// Stdio implements the transport layer of the MCP protocol using stdio communication.
+// It launches a subprocess and communicates with it via standard input/output streams
+// using JSON-RPC messages. The client handles message routing between requests and
+// responses, and supports asynchronous notifications.
+type Stdio struct {
+ command string
+ args []string
+ env []string
+
+ cmd *exec.Cmd
+ stdin io.WriteCloser
+ stdout *bufio.Reader
+ stderr io.ReadCloser
+ responses map[string]chan *JSONRPCResponse
+ mu sync.RWMutex
+ done chan struct{}
+ onNotification func(mcp.JSONRPCNotification)
+ notifyMu sync.RWMutex
+}
+
+// NewIO returns a new stdio-based transport using existing input, output, and
+// logging streams instead of spawning a subprocess.
+// This is useful for testing and simulating client behavior.
+func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
+ return &Stdio{
+ stdin: output,
+ stdout: bufio.NewReader(input),
+ stderr: logging,
+
+ responses: make(map[string]chan *JSONRPCResponse),
+ done: make(chan struct{}),
+ }
+}
+
+// NewStdio creates a new stdio transport to communicate with a subprocess.
+// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
+// Returns an error if the subprocess cannot be started or the pipes cannot be created.
+func NewStdio(
+ command string,
+ env []string,
+ args ...string,
+) *Stdio {
+
+ client := &Stdio{
+ command: command,
+ args: args,
+ env: env,
+
+ responses: make(map[string]chan *JSONRPCResponse),
+ done: make(chan struct{}),
+ }
+
+ return client
+}
+
+func (c *Stdio) Start(ctx context.Context) error {
+ if err := c.spawnCommand(ctx); err != nil {
+ return err
+ }
+
+ ready := make(chan struct{})
+ go func() {
+ close(ready)
+ c.readResponses()
+ }()
+ <-ready
+
+ return nil
+}
+
+// spawnCommand spawns a new process running c.command.
+func (c *Stdio) spawnCommand(ctx context.Context) error {
+ if c.command == "" {
+ return nil
+ }
+
+ cmd := exec.CommandContext(ctx, c.command, c.args...)
+
+ mergedEnv := os.Environ()
+ mergedEnv = append(mergedEnv, c.env...)
+
+ cmd.Env = mergedEnv
+
+ stdin, err := cmd.StdinPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stdin pipe: %w", err)
+ }
+
+ stdout, err := cmd.StdoutPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stdout pipe: %w", err)
+ }
+
+ stderr, err := cmd.StderrPipe()
+ if err != nil {
+ return fmt.Errorf("failed to create stderr pipe: %w", err)
+ }
+
+ c.cmd = cmd
+ c.stdin = stdin
+ c.stderr = stderr
+ c.stdout = bufio.NewReader(stdout)
+
+ if err := cmd.Start(); err != nil {
+ return fmt.Errorf("failed to start command: %w", err)
+ }
+
+ return nil
+}
+
+// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
+// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
+func (c *Stdio) Close() error {
+ select {
+ case <-c.done:
+ return nil
+ default:
+ }
+ // cancel all in-flight request
+ close(c.done)
+
+ if err := c.stdin.Close(); err != nil {
+ return fmt.Errorf("failed to close stdin: %w", err)
+ }
+ if err := c.stderr.Close(); err != nil {
+ return fmt.Errorf("failed to close stderr: %w", err)
+ }
+
+ if c.cmd != nil {
+ return c.cmd.Wait()
+ }
+
+ return nil
+}
+
+// SetNotificationHandler sets the handler function to be called when a notification is received.
+// Only one handler can be set at a time; setting a new one replaces the previous handler.
+func (c *Stdio) SetNotificationHandler(
+ handler func(notification mcp.JSONRPCNotification),
+) {
+ c.notifyMu.Lock()
+ defer c.notifyMu.Unlock()
+ c.onNotification = handler
+}
+
+// readResponses continuously reads and processes responses from the server's stdout.
+// It handles both responses to requests and notifications, routing them appropriately.
+// Runs until the done channel is closed or an error occurs reading from stdout.
+func (c *Stdio) readResponses() {
+ for {
+ select {
+ case <-c.done:
+ return
+ default:
+ line, err := c.stdout.ReadString('\n')
+ if err != nil {
+ if err != io.EOF {
+ fmt.Printf("Error reading response: %v\n", err)
+ }
+ return
+ }
+
+ var baseMessage JSONRPCResponse
+ if err := json.Unmarshal([]byte(line), &baseMessage); err != nil {
+ continue
+ }
+
+ // Handle notification
+ if baseMessage.ID.IsNil() {
+ var notification mcp.JSONRPCNotification
+ if err := json.Unmarshal([]byte(line), ¬ification); err != nil {
+ continue
+ }
+ c.notifyMu.RLock()
+ if c.onNotification != nil {
+ c.onNotification(notification)
+ }
+ c.notifyMu.RUnlock()
+ continue
+ }
+
+ // Create string key for map lookup
+ idKey := baseMessage.ID.String()
+
+ c.mu.RLock()
+ ch, exists := c.responses[idKey]
+ c.mu.RUnlock()
+
+ if exists {
+ ch <- &baseMessage
+ c.mu.Lock()
+ delete(c.responses, idKey)
+ c.mu.Unlock()
+ }
+ }
+ }
+}
+
+// SendRequest sends a JSON-RPC request to the server and waits for a response.
+// It creates a unique request ID, sends the request over stdin, and waits for
+// the corresponding response or context cancellation.
+// Returns the raw JSON response message or an error if the request fails.
+func (c *Stdio) SendRequest(
+ ctx context.Context,
+ request JSONRPCRequest,
+) (*JSONRPCResponse, error) {
+ if c.stdin == nil {
+ return nil, fmt.Errorf("stdio client not started")
+ }
+
+ // Marshal request
+ requestBytes, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+ requestBytes = append(requestBytes, '\n')
+
+ // Create string key for map lookup
+ idKey := request.ID.String()
+
+ // Register response channel
+ responseChan := make(chan *JSONRPCResponse, 1)
+ c.mu.Lock()
+ c.responses[idKey] = responseChan
+ c.mu.Unlock()
+ deleteResponseChan := func() {
+ c.mu.Lock()
+ delete(c.responses, idKey)
+ c.mu.Unlock()
+ }
+
+ // Send request
+ if _, err := c.stdin.Write(requestBytes); err != nil {
+ deleteResponseChan()
+ return nil, fmt.Errorf("failed to write request: %w", err)
+ }
+
+ select {
+ case <-ctx.Done():
+ deleteResponseChan()
+ return nil, ctx.Err()
+ case response := <-responseChan:
+ return response, nil
+ }
+}
+
+// SendNotification sends a json RPC Notification to the server.
+func (c *Stdio) SendNotification(
+ ctx context.Context,
+ notification mcp.JSONRPCNotification,
+) error {
+ if c.stdin == nil {
+ return fmt.Errorf("stdio client not started")
+ }
+
+ notificationBytes, err := json.Marshal(notification)
+ if err != nil {
+ return fmt.Errorf("failed to marshal notification: %w", err)
+ }
+ notificationBytes = append(notificationBytes, '\n')
+
+ if _, err := c.stdin.Write(notificationBytes); err != nil {
+ return fmt.Errorf("failed to write notification: %w", err)
+ }
+
+ return nil
+}
+
+// Stderr returns a reader for the stderr output of the subprocess.
+// This can be used to capture error messages or logs from the subprocess.
+func (c *Stdio) Stderr() io.Reader {
+ return c.stderr
+}
@@ -0,0 +1,515 @@
+package transport
+
+import (
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "mime"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+type StreamableHTTPCOption func(*StreamableHTTP)
+
+// WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport.
+func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.httpClient = client
+ }
+}
+
+func WithHTTPHeaders(headers map[string]string) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.headers = headers
+ }
+}
+
+func WithHTTPHeaderFunc(headerFunc HTTPHeaderFunc) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.headerFunc = headerFunc
+ }
+}
+
+// WithHTTPTimeout sets the timeout for a HTTP request and stream.
+func WithHTTPTimeout(timeout time.Duration) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.httpClient.Timeout = timeout
+ }
+}
+
+// WithHTTPOAuth enables OAuth authentication for the client.
+func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption {
+ return func(sc *StreamableHTTP) {
+ sc.oauthHandler = NewOAuthHandler(config)
+ }
+}
+
+// StreamableHTTP implements Streamable HTTP transport.
+//
+// It transmits JSON-RPC messages over individual HTTP requests. One message per request.
+// The HTTP response body can either be a single JSON-RPC response,
+// or an upgraded SSE stream that concludes with a JSON-RPC response for the same request.
+//
+// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
+//
+// The current implementation does not support the following features:
+// - batching
+// - continuously listening for server notifications when no request is in flight
+// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server)
+// - resuming stream
+// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery)
+// - server -> client request
+type StreamableHTTP struct {
+ serverURL *url.URL
+ httpClient *http.Client
+ headers map[string]string
+ headerFunc HTTPHeaderFunc
+
+ sessionID atomic.Value // string
+
+ notificationHandler func(mcp.JSONRPCNotification)
+ notifyMu sync.RWMutex
+
+ closed chan struct{}
+
+ // OAuth support
+ oauthHandler *OAuthHandler
+}
+
+// NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL.
+// Returns an error if the URL is invalid.
+func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*StreamableHTTP, error) {
+ parsedURL, err := url.Parse(serverURL)
+ if err != nil {
+ return nil, fmt.Errorf("invalid URL: %w", err)
+ }
+
+ smc := &StreamableHTTP{
+ serverURL: parsedURL,
+ httpClient: &http.Client{},
+ headers: make(map[string]string),
+ closed: make(chan struct{}),
+ }
+ smc.sessionID.Store("") // set initial value to simplify later usage
+
+ for _, opt := range options {
+ opt(smc)
+ }
+
+ // If OAuth is configured, set the base URL for metadata discovery
+ if smc.oauthHandler != nil {
+ // Extract base URL from server URL for metadata discovery
+ baseURL := fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host)
+ smc.oauthHandler.SetBaseURL(baseURL)
+ }
+
+ return smc, nil
+}
+
+// Start initiates the HTTP connection to the server.
+func (c *StreamableHTTP) Start(ctx context.Context) error {
+ // For Streamable HTTP, we don't need to establish a persistent connection
+ return nil
+}
+
+// Close closes the all the HTTP connections to the server.
+func (c *StreamableHTTP) Close() error {
+ select {
+ case <-c.closed:
+ return nil
+ default:
+ }
+ // Cancel all in-flight requests
+ close(c.closed)
+
+ sessionId := c.sessionID.Load().(string)
+ if sessionId != "" {
+ c.sessionID.Store("")
+
+ // notify server session closed
+ go func() {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil)
+ if err != nil {
+ fmt.Printf("failed to create close request\n: %v", err)
+ return
+ }
+ req.Header.Set(headerKeySessionID, sessionId)
+ res, err := c.httpClient.Do(req)
+ if err != nil {
+ fmt.Printf("failed to send close request\n: %v", err)
+ return
+ }
+ res.Body.Close()
+ }()
+ }
+
+ return nil
+}
+
+const (
+ headerKeySessionID = "Mcp-Session-Id"
+)
+
+// ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required
+var ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required")
+
+// OAuthAuthorizationRequiredError is returned when OAuth authorization is required
+type OAuthAuthorizationRequiredError struct {
+ Handler *OAuthHandler
+}
+
+func (e *OAuthAuthorizationRequiredError) Error() string {
+ return ErrOAuthAuthorizationRequired.Error()
+}
+
+func (e *OAuthAuthorizationRequiredError) Unwrap() error {
+ return ErrOAuthAuthorizationRequired
+}
+
+// SendRequest sends a JSON-RPC request to the server and waits for a response.
+// Returns the raw JSON response message or an error if the request fails.
+func (c *StreamableHTTP) SendRequest(
+ ctx context.Context,
+ request JSONRPCRequest,
+) (*JSONRPCResponse, error) {
+
+ // Create a combined context that could be canceled when the client is closed
+ newCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+ go func() {
+ select {
+ case <-c.closed:
+ cancel()
+ case <-newCtx.Done():
+ // The original context was canceled, no need to do anything
+ }
+ }()
+ ctx = newCtx
+
+ // Marshal request
+ requestBody, err := json.Marshal(request)
+ if err != nil {
+ return nil, fmt.Errorf("failed to marshal request: %w", err)
+ }
+
+ // Create HTTP request
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
+ if err != nil {
+ return nil, fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json, text/event-stream")
+ sessionID := c.sessionID.Load()
+ if sessionID != "" {
+ req.Header.Set(headerKeySessionID, sessionID.(string))
+ }
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if err.Error() == "no valid token available, authorization required" {
+ return nil, &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return nil, fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ // Send request
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ // Check if we got an error response
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+ // handle session closed
+ if resp.StatusCode == http.StatusNotFound {
+ c.sessionID.CompareAndSwap(sessionID, "")
+ return nil, fmt.Errorf("session terminated (404). need to re-initialize")
+ }
+
+ // Handle OAuth unauthorized error
+ if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
+ return nil, &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+
+ // handle error response
+ var errResponse JSONRPCResponse
+ body, _ := io.ReadAll(resp.Body)
+ if err := json.Unmarshal(body, &errResponse); err == nil {
+ return &errResponse, nil
+ }
+ return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body)
+ }
+
+ if request.Method == string(mcp.MethodInitialize) {
+ // saved the received session ID in the response
+ // empty session ID is allowed
+ if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" {
+ c.sessionID.Store(sessionID)
+ }
+ }
+
+ // Handle different response types
+ mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
+ switch mediaType {
+ case "application/json":
+ // Single response
+ var response JSONRPCResponse
+ if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
+ return nil, fmt.Errorf("failed to decode response: %w", err)
+ }
+
+ // should not be a notification
+ if response.ID.IsNil() {
+ return nil, fmt.Errorf("response should contain RPC id: %v", response)
+ }
+
+ return &response, nil
+
+ case "text/event-stream":
+ // Server is using SSE for streaming responses
+ return c.handleSSEResponse(ctx, resp.Body)
+
+ default:
+ return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type"))
+ }
+}
+
+// handleSSEResponse processes an SSE stream for a specific request.
+// It returns the final result for the request once received, or an error.
+func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) {
+
+ // Create a channel for this specific request
+ responseChan := make(chan *JSONRPCResponse, 1)
+
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ // Start a goroutine to process the SSE stream
+ go func() {
+ // only close responseChan after readingSSE()
+ defer close(responseChan)
+
+ c.readSSE(ctx, reader, func(event, data string) {
+
+ // (unsupported: batching)
+
+ var message JSONRPCResponse
+ if err := json.Unmarshal([]byte(data), &message); err != nil {
+ fmt.Printf("failed to unmarshal message: %v\n", err)
+ return
+ }
+
+ // Handle notification
+ if message.ID.IsNil() {
+ var notification mcp.JSONRPCNotification
+ if err := json.Unmarshal([]byte(data), ¬ification); err != nil {
+ fmt.Printf("failed to unmarshal notification: %v\n", err)
+ return
+ }
+ c.notifyMu.RLock()
+ if c.notificationHandler != nil {
+ c.notificationHandler(notification)
+ }
+ c.notifyMu.RUnlock()
+ return
+ }
+
+ responseChan <- &message
+ })
+ }()
+
+ // Wait for the response or context cancellation
+ select {
+ case response := <-responseChan:
+ if response == nil {
+ return nil, fmt.Errorf("unexpected nil response")
+ }
+ return response, nil
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+}
+
+// readSSE reads the SSE stream(reader) and calls the handler for each event and data pair.
+// It will end when the reader is closed (or the context is done).
+func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, handler func(event, data string)) {
+ defer reader.Close()
+
+ br := bufio.NewReader(reader)
+ var event, data string
+
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ line, err := br.ReadString('\n')
+ if err != nil {
+ if err == io.EOF {
+ // Process any pending event before exit
+ if data != "" {
+ // If no event type is specified, use empty string (default event type)
+ if event == "" {
+ event = "message"
+ }
+ handler(event, data)
+ }
+ return
+ }
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ fmt.Printf("SSE stream error: %v\n", err)
+ return
+ }
+ }
+
+ // Remove only newline markers
+ line = strings.TrimRight(line, "\r\n")
+ if line == "" {
+ // Empty line means end of event
+ if data != "" {
+ // If no event type is specified, use empty string (default event type)
+ if event == "" {
+ event = "message"
+ }
+ handler(event, data)
+ event = ""
+ data = ""
+ }
+ continue
+ }
+
+ if strings.HasPrefix(line, "event:") {
+ event = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
+ } else if strings.HasPrefix(line, "data:") {
+ data = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
+ }
+ }
+ }
+}
+
+func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error {
+
+ // Marshal request
+ requestBody, err := json.Marshal(notification)
+ if err != nil {
+ return fmt.Errorf("failed to marshal notification: %w", err)
+ }
+
+ // Create HTTP request
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody))
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+
+ // Set headers
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json, text/event-stream")
+ if sessionID := c.sessionID.Load(); sessionID != "" {
+ req.Header.Set(headerKeySessionID, sessionID.(string))
+ }
+ for k, v := range c.headers {
+ req.Header.Set(k, v)
+ }
+
+ // Add OAuth authorization if configured
+ if c.oauthHandler != nil {
+ authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx)
+ if err != nil {
+ // If we get an authorization error, return a specific error that can be handled by the client
+ if errors.Is(err, ErrOAuthAuthorizationRequired) {
+ return &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+ return fmt.Errorf("failed to get authorization header: %w", err)
+ }
+ req.Header.Set("Authorization", authHeader)
+ }
+
+ if c.headerFunc != nil {
+ for k, v := range c.headerFunc(ctx) {
+ req.Header.Set(k, v)
+ }
+ }
+
+ // Send request
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return fmt.Errorf("failed to send request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
+ // Handle OAuth unauthorized error
+ if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil {
+ return &OAuthAuthorizationRequiredError{
+ Handler: c.oauthHandler,
+ }
+ }
+
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf(
+ "notification failed with status %d: %s",
+ resp.StatusCode,
+ body,
+ )
+ }
+
+ return nil
+}
+
+func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotification)) {
+ c.notifyMu.Lock()
+ defer c.notifyMu.Unlock()
+ c.notificationHandler = handler
+}
+
+func (c *StreamableHTTP) GetSessionId() string {
+ return c.sessionID.Load().(string)
+}
+
+// GetOAuthHandler returns the OAuth handler if configured
+func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler {
+ return c.oauthHandler
+}
+
+// IsOAuthEnabled returns true if OAuth is enabled
+func (c *StreamableHTTP) IsOAuthEnabled() bool {
+ return c.oauthHandler != nil
+}
@@ -1,8 +0,0 @@
-package client
-
-import "encoding/json"
-
-type RPCResponse struct {
- Error *string
- Response *json.RawMessage
-}
@@ -19,12 +19,14 @@ type ListPromptsResult struct {
// server.
type GetPromptRequest struct {
Request
- Params struct {
- // The name of the prompt or prompt template.
- Name string `json:"name"`
- // Arguments to use for templating the prompt.
- Arguments map[string]string `json:"arguments,omitempty"`
- } `json:"params"`
+ Params GetPromptParams `json:"params"`
+}
+
+type GetPromptParams struct {
+ // The name of the prompt or prompt template.
+ Name string `json:"name"`
+ // Arguments to use for templating the prompt.
+ Arguments map[string]string `json:"arguments,omitempty"`
}
// GetPromptResult is the server's response to a prompts/get request from the
@@ -50,6 +52,11 @@ type Prompt struct {
Arguments []PromptArgument `json:"arguments,omitempty"`
}
+// GetName returns the name of the prompt.
+func (p Prompt) GetName() string {
+ return p.Name
+}
+
// PromptArgument describes an argument that a prompt template can accept.
// When a prompt includes arguments, clients must provide values for all
// required arguments when making a prompts/get request.
@@ -78,7 +85,7 @@ const (
// resources from the MCP server.
type PromptMessage struct {
Role Role `json:"role"`
- Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource
+ Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource
}
// PromptListChangedNotification is an optional notification from the server
@@ -43,10 +43,7 @@ func WithMIMEType(mimeType string) ResourceOption {
func WithAnnotations(audience []Role, priority float64) ResourceOption {
return func(r *Resource) {
if r.Annotations == nil {
- r.Annotations = &struct {
- Audience []Role `json:"audience,omitempty"`
- Priority float64 `json:"priority,omitempty"`
- }{}
+ r.Annotations = &Annotations{}
}
r.Annotations.Audience = audience
r.Annotations.Priority = priority
@@ -94,10 +91,7 @@ func WithTemplateMIMEType(mimeType string) ResourceTemplateOption {
func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption {
return func(t *ResourceTemplate) {
if t.Annotations == nil {
- t.Annotations = &struct {
- Audience []Role `json:"audience,omitempty"`
- Priority float64 `json:"priority,omitempty"`
- }{}
+ t.Annotations = &Annotations{}
}
t.Annotations.Audience = audience
t.Annotations.Priority = priority
@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"fmt"
+ "reflect"
+ "strconv"
)
var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both")
@@ -33,7 +35,7 @@ type ListToolsResult struct {
// should be reported as an MCP error response.
type CallToolResult struct {
Result
- Content []Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource
+ Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource
// Whether the tool call ended in an error.
//
// If not set, this is assumed to be false (the call was successful).
@@ -43,19 +45,420 @@ type CallToolResult struct {
// CallToolRequest is used by the client to invoke a tool provided by the server.
type CallToolRequest struct {
Request
- Params struct {
- Name string `json:"name"`
- Arguments map[string]interface{} `json:"arguments,omitempty"`
- Meta *struct {
- // If specified, the caller is requesting out-of-band progress
- // notifications for this request (as represented by
- // notifications/progress). The value of this parameter is an
- // opaque token that will be attached to any subsequent
- // notifications. The receiver is not obligated to provide these
- // notifications.
- ProgressToken ProgressToken `json:"progressToken,omitempty"`
- } `json:"_meta,omitempty"`
- } `json:"params"`
+ Params CallToolParams `json:"params"`
+}
+
+type CallToolParams struct {
+ Name string `json:"name"`
+ Arguments any `json:"arguments,omitempty"`
+ Meta *Meta `json:"_meta,omitempty"`
+}
+
+// GetArguments returns the Arguments as map[string]any for backward compatibility
+// If Arguments is not a map, it returns an empty map
+func (r CallToolRequest) GetArguments() map[string]any {
+ if args, ok := r.Params.Arguments.(map[string]any); ok {
+ return args
+ }
+ return nil
+}
+
+// GetRawArguments returns the Arguments as-is without type conversion
+// This allows users to access the raw arguments in any format
+func (r CallToolRequest) GetRawArguments() any {
+ return r.Params.Arguments
+}
+
+// BindArguments unmarshals the Arguments into the provided struct
+// This is useful for working with strongly-typed arguments
+func (r CallToolRequest) BindArguments(target any) error {
+ if target == nil || reflect.ValueOf(target).Kind() != reflect.Ptr {
+ return fmt.Errorf("target must be a non-nil pointer")
+ }
+
+ // Fast-path: already raw JSON
+ if raw, ok := r.Params.Arguments.(json.RawMessage); ok {
+ return json.Unmarshal(raw, target)
+ }
+
+ data, err := json.Marshal(r.Params.Arguments)
+ if err != nil {
+ return fmt.Errorf("failed to marshal arguments: %w", err)
+ }
+
+ return json.Unmarshal(data, target)
+}
+
+// GetString returns a string argument by key, or the default value if not found
+func (r CallToolRequest) GetString(key string, defaultValue string) string {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ if str, ok := val.(string); ok {
+ return str
+ }
+ }
+ return defaultValue
+}
+
+// RequireString returns a string argument by key, or an error if not found or not a string
+func (r CallToolRequest) RequireString(key string) (string, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ if str, ok := val.(string); ok {
+ return str, nil
+ }
+ return "", fmt.Errorf("argument %q is not a string", key)
+ }
+ return "", fmt.Errorf("required argument %q not found", key)
+}
+
+// GetInt returns an int argument by key, or the default value if not found
+func (r CallToolRequest) GetInt(key string, defaultValue int) int {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case int:
+ return v
+ case float64:
+ return int(v)
+ case string:
+ if i, err := strconv.Atoi(v); err == nil {
+ return i
+ }
+ }
+ }
+ return defaultValue
+}
+
+// RequireInt returns an int argument by key, or an error if not found or not convertible to int
+func (r CallToolRequest) RequireInt(key string) (int, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case int:
+ return v, nil
+ case float64:
+ return int(v), nil
+ case string:
+ if i, err := strconv.Atoi(v); err == nil {
+ return i, nil
+ }
+ return 0, fmt.Errorf("argument %q cannot be converted to int", key)
+ default:
+ return 0, fmt.Errorf("argument %q is not an int", key)
+ }
+ }
+ return 0, fmt.Errorf("required argument %q not found", key)
+}
+
+// GetFloat returns a float64 argument by key, or the default value if not found
+func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case float64:
+ return v
+ case int:
+ return float64(v)
+ case string:
+ if f, err := strconv.ParseFloat(v, 64); err == nil {
+ return f
+ }
+ }
+ }
+ return defaultValue
+}
+
+// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64
+func (r CallToolRequest) RequireFloat(key string) (float64, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case float64:
+ return v, nil
+ case int:
+ return float64(v), nil
+ case string:
+ if f, err := strconv.ParseFloat(v, 64); err == nil {
+ return f, nil
+ }
+ return 0, fmt.Errorf("argument %q cannot be converted to float64", key)
+ default:
+ return 0, fmt.Errorf("argument %q is not a float64", key)
+ }
+ }
+ return 0, fmt.Errorf("required argument %q not found", key)
+}
+
+// GetBool returns a bool argument by key, or the default value if not found
+func (r CallToolRequest) GetBool(key string, defaultValue bool) bool {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case bool:
+ return v
+ case string:
+ if b, err := strconv.ParseBool(v); err == nil {
+ return b
+ }
+ case int:
+ return v != 0
+ case float64:
+ return v != 0
+ }
+ }
+ return defaultValue
+}
+
+// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool
+func (r CallToolRequest) RequireBool(key string) (bool, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case bool:
+ return v, nil
+ case string:
+ if b, err := strconv.ParseBool(v); err == nil {
+ return b, nil
+ }
+ return false, fmt.Errorf("argument %q cannot be converted to bool", key)
+ case int:
+ return v != 0, nil
+ case float64:
+ return v != 0, nil
+ default:
+ return false, fmt.Errorf("argument %q is not a bool", key)
+ }
+ }
+ return false, fmt.Errorf("required argument %q not found", key)
+}
+
+// GetStringSlice returns a string slice argument by key, or the default value if not found
+func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []string:
+ return v
+ case []any:
+ result := make([]string, 0, len(v))
+ for _, item := range v {
+ if str, ok := item.(string); ok {
+ result = append(result, str)
+ }
+ }
+ return result
+ }
+ }
+ return defaultValue
+}
+
+// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice
+func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []string:
+ return v, nil
+ case []any:
+ result := make([]string, 0, len(v))
+ for i, item := range v {
+ if str, ok := item.(string); ok {
+ result = append(result, str)
+ } else {
+ return nil, fmt.Errorf("item %d in argument %q is not a string", i, key)
+ }
+ }
+ return result, nil
+ default:
+ return nil, fmt.Errorf("argument %q is not a string slice", key)
+ }
+ }
+ return nil, fmt.Errorf("required argument %q not found", key)
+}
+
+// GetIntSlice returns an int slice argument by key, or the default value if not found
+func (r CallToolRequest) GetIntSlice(key string, defaultValue []int) []int {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []int:
+ return v
+ case []any:
+ result := make([]int, 0, len(v))
+ for _, item := range v {
+ switch num := item.(type) {
+ case int:
+ result = append(result, num)
+ case float64:
+ result = append(result, int(num))
+ case string:
+ if i, err := strconv.Atoi(num); err == nil {
+ result = append(result, i)
+ }
+ }
+ }
+ return result
+ }
+ }
+ return defaultValue
+}
+
+// RequireIntSlice returns an int slice argument by key, or an error if not found or not convertible to int slice
+func (r CallToolRequest) RequireIntSlice(key string) ([]int, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []int:
+ return v, nil
+ case []any:
+ result := make([]int, 0, len(v))
+ for i, item := range v {
+ switch num := item.(type) {
+ case int:
+ result = append(result, num)
+ case float64:
+ result = append(result, int(num))
+ case string:
+ if i, err := strconv.Atoi(num); err == nil {
+ result = append(result, i)
+ } else {
+ return nil, fmt.Errorf("item %d in argument %q cannot be converted to int", i, key)
+ }
+ default:
+ return nil, fmt.Errorf("item %d in argument %q is not an int", i, key)
+ }
+ }
+ return result, nil
+ default:
+ return nil, fmt.Errorf("argument %q is not an int slice", key)
+ }
+ }
+ return nil, fmt.Errorf("required argument %q not found", key)
+}
+
+// GetFloatSlice returns a float64 slice argument by key, or the default value if not found
+func (r CallToolRequest) GetFloatSlice(key string, defaultValue []float64) []float64 {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []float64:
+ return v
+ case []any:
+ result := make([]float64, 0, len(v))
+ for _, item := range v {
+ switch num := item.(type) {
+ case float64:
+ result = append(result, num)
+ case int:
+ result = append(result, float64(num))
+ case string:
+ if f, err := strconv.ParseFloat(num, 64); err == nil {
+ result = append(result, f)
+ }
+ }
+ }
+ return result
+ }
+ }
+ return defaultValue
+}
+
+// RequireFloatSlice returns a float64 slice argument by key, or an error if not found or not convertible to float64 slice
+func (r CallToolRequest) RequireFloatSlice(key string) ([]float64, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []float64:
+ return v, nil
+ case []any:
+ result := make([]float64, 0, len(v))
+ for i, item := range v {
+ switch num := item.(type) {
+ case float64:
+ result = append(result, num)
+ case int:
+ result = append(result, float64(num))
+ case string:
+ if f, err := strconv.ParseFloat(num, 64); err == nil {
+ result = append(result, f)
+ } else {
+ return nil, fmt.Errorf("item %d in argument %q cannot be converted to float64", i, key)
+ }
+ default:
+ return nil, fmt.Errorf("item %d in argument %q is not a float64", i, key)
+ }
+ }
+ return result, nil
+ default:
+ return nil, fmt.Errorf("argument %q is not a float64 slice", key)
+ }
+ }
+ return nil, fmt.Errorf("required argument %q not found", key)
+}
+
+// GetBoolSlice returns a bool slice argument by key, or the default value if not found
+func (r CallToolRequest) GetBoolSlice(key string, defaultValue []bool) []bool {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []bool:
+ return v
+ case []any:
+ result := make([]bool, 0, len(v))
+ for _, item := range v {
+ switch b := item.(type) {
+ case bool:
+ result = append(result, b)
+ case string:
+ if parsed, err := strconv.ParseBool(b); err == nil {
+ result = append(result, parsed)
+ }
+ case int:
+ result = append(result, b != 0)
+ case float64:
+ result = append(result, b != 0)
+ }
+ }
+ return result
+ }
+ }
+ return defaultValue
+}
+
+// RequireBoolSlice returns a bool slice argument by key, or an error if not found or not convertible to bool slice
+func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) {
+ args := r.GetArguments()
+ if val, ok := args[key]; ok {
+ switch v := val.(type) {
+ case []bool:
+ return v, nil
+ case []any:
+ result := make([]bool, 0, len(v))
+ for i, item := range v {
+ switch b := item.(type) {
+ case bool:
+ result = append(result, b)
+ case string:
+ if parsed, err := strconv.ParseBool(b); err == nil {
+ result = append(result, parsed)
+ } else {
+ return nil, fmt.Errorf("item %d in argument %q cannot be converted to bool", i, key)
+ }
+ case int:
+ result = append(result, b != 0)
+ case float64:
+ result = append(result, b != 0)
+ default:
+ return nil, fmt.Errorf("item %d in argument %q is not a bool", i, key)
+ }
+ }
+ return result, nil
+ default:
+ return nil, fmt.Errorf("argument %q is not a bool slice", key)
+ }
+ }
+ return nil, fmt.Errorf("required argument %q not found", key)
}
// ToolListChangedNotification is an optional notification from the server to
@@ -75,13 +478,20 @@ type Tool struct {
InputSchema ToolInputSchema `json:"inputSchema"`
// Alternative to InputSchema - allows arbitrary JSON Schema to be provided
RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling
+ // Optional properties describing tool behavior
+ Annotations ToolAnnotation `json:"annotations"`
+}
+
+// GetName returns the name of the tool.
+func (t Tool) GetName() string {
+ return t.Name
}
// MarshalJSON implements the json.Marshaler interface for Tool.
// It handles marshaling either InputSchema or RawInputSchema based on which is set.
func (t Tool) MarshalJSON() ([]byte, error) {
// Create a map to build the JSON structure
- m := make(map[string]interface{}, 3)
+ m := make(map[string]any, 3)
// Add the name and description
m["name"] = t.Name
@@ -100,13 +510,45 @@ func (t Tool) MarshalJSON() ([]byte, error) {
m["inputSchema"] = t.InputSchema
}
+ m["annotations"] = t.Annotations
+
return json.Marshal(m)
}
type ToolInputSchema struct {
- Type string `json:"type"`
- Properties map[string]interface{} `json:"properties"`
- Required []string `json:"required,omitempty"`
+ Type string `json:"type"`
+ Properties map[string]any `json:"properties,omitempty"`
+ Required []string `json:"required,omitempty"`
+}
+
+// MarshalJSON implements the json.Marshaler interface for ToolInputSchema.
+func (tis ToolInputSchema) MarshalJSON() ([]byte, error) {
+ m := make(map[string]any)
+ m["type"] = tis.Type
+
+ // Marshal Properties to '{}' rather than `nil` when its length equals zero
+ if tis.Properties != nil {
+ m["properties"] = tis.Properties
+ }
+
+ if len(tis.Required) > 0 {
+ m["required"] = tis.Required
+ }
+
+ return json.Marshal(m)
+}
+
+type ToolAnnotation struct {
+ // Human-readable title for the tool
+ Title string `json:"title,omitempty"`
+ // If true, the tool does not modify its environment
+ ReadOnlyHint *bool `json:"readOnlyHint,omitempty"`
+ // If true, the tool may perform destructive updates
+ DestructiveHint *bool `json:"destructiveHint,omitempty"`
+ // If true, repeated calls with same args have no additional effect
+ IdempotentHint *bool `json:"idempotentHint,omitempty"`
+ // If true, tool interacts with external entities
+ OpenWorldHint *bool `json:"openWorldHint,omitempty"`
}
// ToolOption is a function that configures a Tool.
@@ -115,7 +557,7 @@ type ToolOption func(*Tool)
// PropertyOption is a function that configures a property in a Tool's input schema.
// It allows for flexible configuration of JSON Schema properties using the functional options pattern.
-type PropertyOption func(map[string]interface{})
+type PropertyOption func(map[string]any)
//
// Core Tool Functions
@@ -129,9 +571,16 @@ func NewTool(name string, opts ...ToolOption) Tool {
Name: name,
InputSchema: ToolInputSchema{
Type: "object",
- Properties: make(map[string]interface{}),
+ Properties: make(map[string]any),
Required: nil, // Will be omitted from JSON if empty
},
+ Annotations: ToolAnnotation{
+ Title: "",
+ ReadOnlyHint: ToBoolPtr(false),
+ DestructiveHint: ToBoolPtr(true),
+ IdempotentHint: ToBoolPtr(false),
+ OpenWorldHint: ToBoolPtr(true),
+ },
}
for _, opt := range opts {
@@ -166,6 +615,53 @@ func WithDescription(description string) ToolOption {
}
}
+// WithToolAnnotation adds optional hints about the Tool.
+func WithToolAnnotation(annotation ToolAnnotation) ToolOption {
+ return func(t *Tool) {
+ t.Annotations = annotation
+ }
+}
+
+// WithTitleAnnotation sets the Title field of the Tool's Annotations.
+// It provides a human-readable title for the tool.
+func WithTitleAnnotation(title string) ToolOption {
+ return func(t *Tool) {
+ t.Annotations.Title = title
+ }
+}
+
+// WithReadOnlyHintAnnotation sets the ReadOnlyHint field of the Tool's Annotations.
+// If true, it indicates the tool does not modify its environment.
+func WithReadOnlyHintAnnotation(value bool) ToolOption {
+ return func(t *Tool) {
+ t.Annotations.ReadOnlyHint = &value
+ }
+}
+
+// WithDestructiveHintAnnotation sets the DestructiveHint field of the Tool's Annotations.
+// If true, it indicates the tool may perform destructive updates.
+func WithDestructiveHintAnnotation(value bool) ToolOption {
+ return func(t *Tool) {
+ t.Annotations.DestructiveHint = &value
+ }
+}
+
+// WithIdempotentHintAnnotation sets the IdempotentHint field of the Tool's Annotations.
+// If true, it indicates repeated calls with the same arguments have no additional effect.
+func WithIdempotentHintAnnotation(value bool) ToolOption {
+ return func(t *Tool) {
+ t.Annotations.IdempotentHint = &value
+ }
+}
+
+// WithOpenWorldHintAnnotation sets the OpenWorldHint field of the Tool's Annotations.
+// If true, it indicates the tool interacts with external entities.
+func WithOpenWorldHintAnnotation(value bool) ToolOption {
+ return func(t *Tool) {
+ t.Annotations.OpenWorldHint = &value
+ }
+}
+
//
// Common Property Options
//
@@ -173,7 +669,7 @@ func WithDescription(description string) ToolOption {
// Description adds a description to a property in the JSON Schema.
// The description should explain the purpose and expected values of the property.
func Description(desc string) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["description"] = desc
}
}
@@ -181,7 +677,7 @@ func Description(desc string) PropertyOption {
// Required marks a property as required in the tool's input schema.
// Required properties must be provided when using the tool.
func Required() PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["required"] = true
}
}
@@ -189,7 +685,7 @@ func Required() PropertyOption {
// Title adds a display-friendly title to a property in the JSON Schema.
// This title can be used by UI components to show a more readable property name.
func Title(title string) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["title"] = title
}
}
@@ -201,7 +697,7 @@ func Title(title string) PropertyOption {
// DefaultString sets the default value for a string property.
// This value will be used if the property is not explicitly provided.
func DefaultString(value string) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["default"] = value
}
}
@@ -209,7 +705,7 @@ func DefaultString(value string) PropertyOption {
// Enum specifies a list of allowed values for a string property.
// The property value must be one of the specified enum values.
func Enum(values ...string) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["enum"] = values
}
}
@@ -217,7 +713,7 @@ func Enum(values ...string) PropertyOption {
// MaxLength sets the maximum length for a string property.
// The string value must not exceed this length.
func MaxLength(max int) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["maxLength"] = max
}
}
@@ -225,7 +721,7 @@ func MaxLength(max int) PropertyOption {
// MinLength sets the minimum length for a string property.
// The string value must be at least this length.
func MinLength(min int) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["minLength"] = min
}
}
@@ -233,7 +729,7 @@ func MinLength(min int) PropertyOption {
// Pattern sets a regex pattern that a string property must match.
// The string value must conform to the specified regular expression.
func Pattern(pattern string) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["pattern"] = pattern
}
}
@@ -245,7 +741,7 @@ func Pattern(pattern string) PropertyOption {
// DefaultNumber sets the default value for a number property.
// This value will be used if the property is not explicitly provided.
func DefaultNumber(value float64) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["default"] = value
}
}
@@ -253,7 +749,7 @@ func DefaultNumber(value float64) PropertyOption {
// Max sets the maximum value for a number property.
// The number value must not exceed this maximum.
func Max(max float64) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["maximum"] = max
}
}
@@ -261,7 +757,7 @@ func Max(max float64) PropertyOption {
// Min sets the minimum value for a number property.
// The number value must not be less than this minimum.
func Min(min float64) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["minimum"] = min
}
}
@@ -269,7 +765,7 @@ func Min(min float64) PropertyOption {
// MultipleOf specifies that a number must be a multiple of the given value.
// The number value must be divisible by this value.
func MultipleOf(value float64) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["multipleOf"] = value
}
}
@@ -281,7 +777,19 @@ func MultipleOf(value float64) PropertyOption {
// DefaultBool sets the default value for a boolean property.
// This value will be used if the property is not explicitly provided.
func DefaultBool(value bool) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
+ schema["default"] = value
+ }
+}
+
+//
+// Array Property Options
+//
+
+// DefaultArray sets the default value for an array property.
+// This value will be used if the property is not explicitly provided.
+func DefaultArray[T any](value []T) PropertyOption {
+ return func(schema map[string]any) {
schema["default"] = value
}
}
@@ -294,7 +802,7 @@ func DefaultBool(value bool) PropertyOption {
// It accepts property options to configure the boolean property's behavior and constraints.
func WithBoolean(name string, opts ...PropertyOption) ToolOption {
return func(t *Tool) {
- schema := map[string]interface{}{
+ schema := map[string]any{
"type": "boolean",
}
@@ -305,11 +813,7 @@ func WithBoolean(name string, opts ...PropertyOption) ToolOption {
// Remove required from property schema and add to InputSchema.required
if required, ok := schema["required"].(bool); ok && required {
delete(schema, "required")
- if t.InputSchema.Required == nil {
- t.InputSchema.Required = []string{name}
- } else {
- t.InputSchema.Required = append(t.InputSchema.Required, name)
- }
+ t.InputSchema.Required = append(t.InputSchema.Required, name)
}
t.InputSchema.Properties[name] = schema
@@ -320,7 +824,7 @@ func WithBoolean(name string, opts ...PropertyOption) ToolOption {
// It accepts property options to configure the number property's behavior and constraints.
func WithNumber(name string, opts ...PropertyOption) ToolOption {
return func(t *Tool) {
- schema := map[string]interface{}{
+ schema := map[string]any{
"type": "number",
}
@@ -331,11 +835,7 @@ func WithNumber(name string, opts ...PropertyOption) ToolOption {
// Remove required from property schema and add to InputSchema.required
if required, ok := schema["required"].(bool); ok && required {
delete(schema, "required")
- if t.InputSchema.Required == nil {
- t.InputSchema.Required = []string{name}
- } else {
- t.InputSchema.Required = append(t.InputSchema.Required, name)
- }
+ t.InputSchema.Required = append(t.InputSchema.Required, name)
}
t.InputSchema.Properties[name] = schema
@@ -346,7 +846,7 @@ func WithNumber(name string, opts ...PropertyOption) ToolOption {
// It accepts property options to configure the string property's behavior and constraints.
func WithString(name string, opts ...PropertyOption) ToolOption {
return func(t *Tool) {
- schema := map[string]interface{}{
+ schema := map[string]any{
"type": "string",
}
@@ -357,11 +857,7 @@ func WithString(name string, opts ...PropertyOption) ToolOption {
// Remove required from property schema and add to InputSchema.required
if required, ok := schema["required"].(bool); ok && required {
delete(schema, "required")
- if t.InputSchema.Required == nil {
- t.InputSchema.Required = []string{name}
- } else {
- t.InputSchema.Required = append(t.InputSchema.Required, name)
- }
+ t.InputSchema.Required = append(t.InputSchema.Required, name)
}
t.InputSchema.Properties[name] = schema
@@ -372,9 +868,9 @@ func WithString(name string, opts ...PropertyOption) ToolOption {
// It accepts property options to configure the object property's behavior and constraints.
func WithObject(name string, opts ...PropertyOption) ToolOption {
return func(t *Tool) {
- schema := map[string]interface{}{
+ schema := map[string]any{
"type": "object",
- "properties": map[string]interface{}{},
+ "properties": map[string]any{},
}
for _, opt := range opts {
@@ -384,11 +880,7 @@ func WithObject(name string, opts ...PropertyOption) ToolOption {
// Remove required from property schema and add to InputSchema.required
if required, ok := schema["required"].(bool); ok && required {
delete(schema, "required")
- if t.InputSchema.Required == nil {
- t.InputSchema.Required = []string{name}
- } else {
- t.InputSchema.Required = append(t.InputSchema.Required, name)
- }
+ t.InputSchema.Required = append(t.InputSchema.Required, name)
}
t.InputSchema.Properties[name] = schema
@@ -399,7 +891,7 @@ func WithObject(name string, opts ...PropertyOption) ToolOption {
// It accepts property options to configure the array property's behavior and constraints.
func WithArray(name string, opts ...PropertyOption) ToolOption {
return func(t *Tool) {
- schema := map[string]interface{}{
+ schema := map[string]any{
"type": "array",
}
@@ -410,11 +902,7 @@ func WithArray(name string, opts ...PropertyOption) ToolOption {
// Remove required from property schema and add to InputSchema.required
if required, ok := schema["required"].(bool); ok && required {
delete(schema, "required")
- if t.InputSchema.Required == nil {
- t.InputSchema.Required = []string{name}
- } else {
- t.InputSchema.Required = append(t.InputSchema.Required, name)
- }
+ t.InputSchema.Required = append(t.InputSchema.Required, name)
}
t.InputSchema.Properties[name] = schema
@@ -422,65 +910,65 @@ func WithArray(name string, opts ...PropertyOption) ToolOption {
}
// Properties defines the properties for an object schema
-func Properties(props map[string]interface{}) PropertyOption {
- return func(schema map[string]interface{}) {
+func Properties(props map[string]any) PropertyOption {
+ return func(schema map[string]any) {
schema["properties"] = props
}
}
// AdditionalProperties specifies whether additional properties are allowed in the object
// or defines a schema for additional properties
-func AdditionalProperties(schema interface{}) PropertyOption {
- return func(schemaMap map[string]interface{}) {
+func AdditionalProperties(schema any) PropertyOption {
+ return func(schemaMap map[string]any) {
schemaMap["additionalProperties"] = schema
}
}
// MinProperties sets the minimum number of properties for an object
func MinProperties(min int) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["minProperties"] = min
}
}
// MaxProperties sets the maximum number of properties for an object
func MaxProperties(max int) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["maxProperties"] = max
}
}
// PropertyNames defines a schema for property names in an object
-func PropertyNames(schema map[string]interface{}) PropertyOption {
- return func(schemaMap map[string]interface{}) {
+func PropertyNames(schema map[string]any) PropertyOption {
+ return func(schemaMap map[string]any) {
schemaMap["propertyNames"] = schema
}
}
// Items defines the schema for array items
-func Items(schema interface{}) PropertyOption {
- return func(schemaMap map[string]interface{}) {
+func Items(schema any) PropertyOption {
+ return func(schemaMap map[string]any) {
schemaMap["items"] = schema
}
}
// MinItems sets the minimum number of items for an array
func MinItems(min int) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["minItems"] = min
}
}
// MaxItems sets the maximum number of items for an array
func MaxItems(max int) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["maxItems"] = max
}
}
// UniqueItems specifies whether array items must be unique
func UniqueItems(unique bool) PropertyOption {
- return func(schema map[string]interface{}) {
+ return func(schema map[string]any) {
schema["uniqueItems"] = unique
}
}
@@ -0,0 +1,20 @@
+package mcp
+
+import (
+ "context"
+ "fmt"
+)
+
+// TypedToolHandlerFunc is a function that handles a tool call with typed arguments
+type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error)
+
+// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct
+func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) {
+ return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) {
+ var args T
+ if err := request.BindArguments(&args); err != nil {
+ return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil
+ }
+ return handler(ctx, request, args)
+ }
+}
@@ -1,9 +1,12 @@
-// Package mcp defines the core types and interfaces for the Model Control Protocol (MCP).
+// Package mcp defines the core types and interfaces for the Model Context Protocol (MCP).
// MCP is a protocol for communication between LLM-powered applications and their supporting services.
package mcp
import (
"encoding/json"
+ "fmt"
+ "maps"
+ "strconv"
"github.com/yosida95/uritemplate/v3"
)
@@ -11,41 +14,59 @@ import (
type MCPMethod string
const (
- // Initiates connection and negotiates protocol capabilities.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization
+ // MethodInitialize initiates connection and negotiates protocol capabilities.
+ // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization
MethodInitialize MCPMethod = "initialize"
- // Verifies connection liveness between client and server.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/
+ // MethodPing verifies connection liveness between client and server.
+ // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/
MethodPing MCPMethod = "ping"
- // Lists all available server resources.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/
+ // MethodResourcesList lists all available server resources.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/
MethodResourcesList MCPMethod = "resources/list"
- // Provides URI templates for constructing resource URIs.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/
+ // MethodResourcesTemplatesList provides URI templates for constructing resource URIs.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/
MethodResourcesTemplatesList MCPMethod = "resources/templates/list"
- // Retrieves content of a specific resource by URI.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/
+ // MethodResourcesRead retrieves content of a specific resource by URI.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/
MethodResourcesRead MCPMethod = "resources/read"
- // Lists all available prompt templates.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/
+ // MethodPromptsList lists all available prompt templates.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/
MethodPromptsList MCPMethod = "prompts/list"
- // Retrieves a specific prompt template with filled parameters.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/
+ // MethodPromptsGet retrieves a specific prompt template with filled parameters.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/
MethodPromptsGet MCPMethod = "prompts/get"
- // Lists all available executable tools.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/
+ // MethodToolsList lists all available executable tools.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/
MethodToolsList MCPMethod = "tools/list"
- // Invokes a specific tool with provided parameters.
- // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/
+ // MethodToolsCall invokes a specific tool with provided parameters.
+ // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/
MethodToolsCall MCPMethod = "tools/call"
+
+ // MethodSetLogLevel configures the minimum log level for client
+ // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging
+ MethodSetLogLevel MCPMethod = "logging/setLevel"
+
+ // MethodNotificationResourcesListChanged notifies when the list of available resources changes.
+ // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification
+ MethodNotificationResourcesListChanged = "notifications/resources/list_changed"
+
+ MethodNotificationResourceUpdated = "notifications/resources/updated"
+
+ // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes.
+ // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification
+ MethodNotificationPromptsListChanged = "notifications/prompts/list_changed"
+
+ // MethodNotificationToolsListChanged notifies when the list of available tools changes.
+ // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/
+ MethodNotificationToolsListChanged = "notifications/tools/list_changed"
)
type URITemplate struct {
@@ -53,7 +74,7 @@ type URITemplate struct {
}
func (t *URITemplate) MarshalJSON() ([]byte, error) {
- return json.Marshal(t.Template.Raw())
+ return json.Marshal(t.Raw())
}
func (t *URITemplate) UnmarshalJSON(data []byte) error {
@@ -72,36 +93,73 @@ func (t *URITemplate) UnmarshalJSON(data []byte) error {
/* JSON-RPC types */
// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError
-type JSONRPCMessage interface{}
+type JSONRPCMessage any
// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol.
-const LATEST_PROTOCOL_VERSION = "2024-11-05"
+const LATEST_PROTOCOL_VERSION = "2025-03-26"
+
+// ValidProtocolVersions lists all known valid MCP protocol versions.
+var ValidProtocolVersions = []string{
+ "2024-11-05",
+ LATEST_PROTOCOL_VERSION,
+}
// JSONRPC_VERSION is the version of JSON-RPC used by MCP.
const JSONRPC_VERSION = "2.0"
// ProgressToken is used to associate progress notifications with the original request.
-type ProgressToken interface{}
+type ProgressToken any
// Cursor is an opaque token used to represent a cursor for pagination.
type Cursor string
+// Meta is metadata attached to a request's parameters. This can include fields
+// formally defined by the protocol or other arbitrary data.
+type Meta struct {
+ // If specified, the caller is requesting out-of-band progress
+ // notifications for this request (as represented by
+ // notifications/progress). The value of this parameter is an
+ // opaque token that will be attached to any subsequent
+ // notifications. The receiver is not obligated to provide these
+ // notifications.
+ ProgressToken ProgressToken
+
+ // AdditionalFields are any fields present in the Meta that are not
+ // otherwise defined in the protocol.
+ AdditionalFields map[string]any
+}
+
+func (m *Meta) MarshalJSON() ([]byte, error) {
+ raw := make(map[string]any)
+ if m.ProgressToken != nil {
+ raw["progressToken"] = m.ProgressToken
+ }
+ maps.Copy(raw, m.AdditionalFields)
+
+ return json.Marshal(raw)
+}
+
+func (m *Meta) UnmarshalJSON(data []byte) error {
+ raw := make(map[string]any)
+ if err := json.Unmarshal(data, &raw); err != nil {
+ return err
+ }
+ m.ProgressToken = raw["progressToken"]
+ delete(raw, "progressToken")
+ m.AdditionalFields = raw
+ return nil
+}
+
type Request struct {
- Method string `json:"method"`
- Params struct {
- Meta *struct {
- // If specified, the caller is requesting out-of-band progress
- // notifications for this request (as represented by
- // notifications/progress). The value of this parameter is an
- // opaque token that will be attached to any subsequent
- // notifications. The receiver is not obligated to provide these
- // notifications.
- ProgressToken ProgressToken `json:"progressToken,omitempty"`
- } `json:"_meta,omitempty"`
- } `json:"params,omitempty"`
-}
-
-type Params map[string]interface{}
+ Method string `json:"method"`
+ Params RequestParams `json:"params,omitempty"`
+}
+
+type RequestParams struct {
+ Meta *Meta `json:"_meta,omitempty"`
+}
+
+type Params map[string]any
type Notification struct {
Method string `json:"method"`
@@ -111,16 +169,16 @@ type Notification struct {
type NotificationParams struct {
// This parameter name is reserved by MCP to allow clients and
// servers to attach additional metadata to their notifications.
- Meta map[string]interface{} `json:"_meta,omitempty"`
+ Meta map[string]any `json:"_meta,omitempty"`
// Additional fields can be added to this map
- AdditionalFields map[string]interface{} `json:"-"`
+ AdditionalFields map[string]any `json:"-"`
}
// MarshalJSON implements custom JSON marshaling
func (p NotificationParams) MarshalJSON() ([]byte, error) {
// Create a map to hold all fields
- m := make(map[string]interface{})
+ m := make(map[string]any)
// Add Meta if it exists
if p.Meta != nil {
@@ -141,24 +199,24 @@ func (p NotificationParams) MarshalJSON() ([]byte, error) {
// UnmarshalJSON implements custom JSON unmarshaling
func (p *NotificationParams) UnmarshalJSON(data []byte) error {
// Create a map to hold all fields
- var m map[string]interface{}
+ var m map[string]any
if err := json.Unmarshal(data, &m); err != nil {
return err
}
// Initialize maps if they're nil
if p.Meta == nil {
- p.Meta = make(map[string]interface{})
+ p.Meta = make(map[string]any)
}
if p.AdditionalFields == nil {
- p.AdditionalFields = make(map[string]interface{})
+ p.AdditionalFields = make(map[string]any)
}
// Process all fields
for k, v := range m {
if k == "_meta" {
// Handle Meta field
- if meta, ok := v.(map[string]interface{}); ok {
+ if meta, ok := v.(map[string]any); ok {
p.Meta = meta
}
} else {
@@ -173,18 +231,86 @@ func (p *NotificationParams) UnmarshalJSON(data []byte) error {
type Result struct {
// This result property is reserved by the protocol to allow clients and
// servers to attach additional metadata to their responses.
- Meta map[string]interface{} `json:"_meta,omitempty"`
+ Meta map[string]any `json:"_meta,omitempty"`
}
// RequestId is a uniquely identifying ID for a request in JSON-RPC.
// It can be any JSON-serializable value, typically a number or string.
-type RequestId interface{}
+type RequestId struct {
+ value any
+}
+
+// NewRequestId creates a new RequestId with the given value
+func NewRequestId(value any) RequestId {
+ return RequestId{value: value}
+}
+
+// Value returns the underlying value of the RequestId
+func (r RequestId) Value() any {
+ return r.value
+}
+
+// String returns a string representation of the RequestId
+func (r RequestId) String() string {
+ switch v := r.value.(type) {
+ case string:
+ return "string:" + v
+ case int64:
+ return "int64:" + strconv.FormatInt(v, 10)
+ case float64:
+ if v == float64(int64(v)) {
+ return "int64:" + strconv.FormatInt(int64(v), 10)
+ }
+ return "float64:" + strconv.FormatFloat(v, 'f', -1, 64)
+ case nil:
+ return "<nil>"
+ default:
+ return "unknown:" + fmt.Sprintf("%v", v)
+ }
+}
+
+// IsNil returns true if the RequestId is nil
+func (r RequestId) IsNil() bool {
+ return r.value == nil
+}
+
+func (r RequestId) MarshalJSON() ([]byte, error) {
+ return json.Marshal(r.value)
+}
+
+func (r *RequestId) UnmarshalJSON(data []byte) error {
+
+ if string(data) == "null" {
+ r.value = nil
+ return nil
+ }
+
+ // Try unmarshaling as string first
+ var s string
+ if err := json.Unmarshal(data, &s); err == nil {
+ r.value = s
+ return nil
+ }
+
+ // JSON numbers are unmarshaled as float64 in Go
+ var f float64
+ if err := json.Unmarshal(data, &f); err == nil {
+ if f == float64(int64(f)) {
+ r.value = int64(f)
+ } else {
+ r.value = f
+ }
+ return nil
+ }
+
+ return fmt.Errorf("invalid request id: %s", string(data))
+}
// JSONRPCRequest represents a request that expects a response.
type JSONRPCRequest struct {
- JSONRPC string `json:"jsonrpc"`
- ID RequestId `json:"id"`
- Params interface{} `json:"params,omitempty"`
+ JSONRPC string `json:"jsonrpc"`
+ ID RequestId `json:"id"`
+ Params any `json:"params,omitempty"`
Request
}
@@ -196,9 +322,9 @@ type JSONRPCNotification struct {
// JSONRPCResponse represents a successful (non-error) response to a request.
type JSONRPCResponse struct {
- JSONRPC string `json:"jsonrpc"`
- ID RequestId `json:"id"`
- Result interface{} `json:"result"`
+ JSONRPC string `json:"jsonrpc"`
+ ID RequestId `json:"id"`
+ Result any `json:"result"`
}
// JSONRPCError represents a non-successful (error) response to a request.
@@ -213,7 +339,7 @@ type JSONRPCError struct {
Message string `json:"message"`
// Additional information about the error. The value of this member
// is defined by the sender (e.g. detailed error information, nested errors etc.).
- Data interface{} `json:"data,omitempty"`
+ Data any `json:"data,omitempty"`
} `json:"error"`
}
@@ -226,6 +352,11 @@ const (
INTERNAL_ERROR = -32603
)
+// MCP error codes
+const (
+ RESOURCE_NOT_FOUND = -32002
+)
+
/* Empty result */
// EmptyResult represents a response that indicates success but carries no data.
@@ -246,17 +377,19 @@ type EmptyResult Result
// A client MUST NOT attempt to cancel its `initialize` request.
type CancelledNotification struct {
Notification
- Params struct {
- // The ID of the request to cancel.
- //
- // This MUST correspond to the ID of a request previously issued
- // in the same direction.
- RequestId RequestId `json:"requestId"`
+ Params CancelledNotificationParams `json:"params"`
+}
+
+type CancelledNotificationParams struct {
+ // The ID of the request to cancel.
+ //
+ // This MUST correspond to the ID of a request previously issued
+ // in the same direction.
+ RequestId RequestId `json:"requestId"`
- // An optional string describing the reason for the cancellation. This MAY
- // be logged or presented to the user.
- Reason string `json:"reason,omitempty"`
- } `json:"params"`
+ // An optional string describing the reason for the cancellation. This MAY
+ // be logged or presented to the user.
+ Reason string `json:"reason,omitempty"`
}
/* Initialization */
@@ -265,13 +398,15 @@ type CancelledNotification struct {
// connects, asking it to begin initialization.
type InitializeRequest struct {
Request
- Params struct {
- // The latest version of the Model Context Protocol that the client supports.
- // The client MAY decide to support older versions as well.
- ProtocolVersion string `json:"protocolVersion"`
- Capabilities ClientCapabilities `json:"capabilities"`
- ClientInfo Implementation `json:"clientInfo"`
- } `json:"params"`
+ Params InitializeParams `json:"params"`
+}
+
+type InitializeParams struct {
+ // The latest version of the Model Context Protocol that the client supports.
+ // The client MAY decide to support older versions as well.
+ ProtocolVersion string `json:"protocolVersion"`
+ Capabilities ClientCapabilities `json:"capabilities"`
+ ClientInfo Implementation `json:"clientInfo"`
}
// InitializeResult is sent after receiving an initialize request from the
@@ -303,7 +438,7 @@ type InitializedNotification struct {
// client can define its own, additional capabilities.
type ClientCapabilities struct {
// Experimental, non-standard capabilities that the client supports.
- Experimental map[string]interface{} `json:"experimental,omitempty"`
+ Experimental map[string]any `json:"experimental,omitempty"`
// Present if the client supports listing roots.
Roots *struct {
// Whether the client supports notifications for changes to the roots list.
@@ -318,7 +453,7 @@ type ClientCapabilities struct {
// server can define its own, additional capabilities.
type ServerCapabilities struct {
// Experimental, non-standard capabilities that the server supports.
- Experimental map[string]interface{} `json:"experimental,omitempty"`
+ Experimental map[string]any `json:"experimental,omitempty"`
// Present if the server supports sending log messages to the client.
Logging *struct{} `json:"logging,omitempty"`
// Present if the server offers any prompt templates.
@@ -362,27 +497,34 @@ type PingRequest struct {
// receiver of a progress update for a long-running request.
type ProgressNotification struct {
Notification
- Params struct {
- // The progress token which was given in the initial request, used to
- // associate this notification with the request that is proceeding.
- ProgressToken ProgressToken `json:"progressToken"`
- // The progress thus far. This should increase every time progress is made,
- // even if the total is unknown.
- Progress float64 `json:"progress"`
- // Total number of items to process (or total progress required), if known.
- Total float64 `json:"total,omitempty"`
- } `json:"params"`
+ Params ProgressNotificationParams `json:"params"`
+}
+
+type ProgressNotificationParams struct {
+ // The progress token which was given in the initial request, used to
+ // associate this notification with the request that is proceeding.
+ ProgressToken ProgressToken `json:"progressToken"`
+ // The progress thus far. This should increase every time progress is made,
+ // even if the total is unknown.
+ Progress float64 `json:"progress"`
+ // Total number of items to process (or total progress required), if known.
+ Total float64 `json:"total,omitempty"`
+ // Message related to progress. This should provide relevant human-readable
+ // progress information.
+ Message string `json:"message,omitempty"`
}
/* Pagination */
type PaginatedRequest struct {
Request
- Params struct {
- // An opaque token representing the current pagination position.
- // If provided, the server should return results starting after this cursor.
- Cursor Cursor `json:"cursor,omitempty"`
- } `json:"params,omitempty"`
+ Params PaginatedParams `json:"params,omitempty"`
+}
+
+type PaginatedParams struct {
+ // An opaque token representing the current pagination position.
+ // If provided, the server should return results starting after this cursor.
+ Cursor Cursor `json:"cursor,omitempty"`
}
type PaginatedResult struct {
@@ -425,13 +567,15 @@ type ListResourceTemplatesResult struct {
// specific resource URI.
type ReadResourceRequest struct {
Request
- Params struct {
- // The URI of the resource to read. The URI can use any protocol; it is up
- // to the server how to interpret it.
- URI string `json:"uri"`
- // Arguments to pass to the resource handler
- Arguments map[string]interface{} `json:"arguments,omitempty"`
- } `json:"params"`
+ Params ReadResourceParams `json:"params"`
+}
+
+type ReadResourceParams struct {
+ // The URI of the resource to read. The URI can use any protocol; it is up
+ // to the server how to interpret it.
+ URI string `json:"uri"`
+ // Arguments to pass to the resource handler
+ Arguments map[string]any `json:"arguments,omitempty"`
}
// ReadResourceResult is the server's response to a resources/read request
@@ -453,11 +597,13 @@ type ResourceListChangedNotification struct {
// notifications from the server whenever a particular resource changes.
type SubscribeRequest struct {
Request
- Params struct {
- // The URI of the resource to subscribe to. The URI can use any protocol; it
- // is up to the server how to interpret it.
- URI string `json:"uri"`
- } `json:"params"`
+ Params SubscribeParams `json:"params"`
+}
+
+type SubscribeParams struct {
+ // The URI of the resource to subscribe to. The URI can use any protocol; it
+ // is up to the server how to interpret it.
+ URI string `json:"uri"`
}
// UnsubscribeRequest is sent from the client to request cancellation of
@@ -465,10 +611,12 @@ type SubscribeRequest struct {
// resources/subscribe request.
type UnsubscribeRequest struct {
Request
- Params struct {
- // The URI of the resource to unsubscribe from.
- URI string `json:"uri"`
- } `json:"params"`
+ Params UnsubscribeParams `json:"params"`
+}
+
+type UnsubscribeParams struct {
+ // The URI of the resource to unsubscribe from.
+ URI string `json:"uri"`
}
// ResourceUpdatedNotification is a notification from the server to the client,
@@ -476,11 +624,12 @@ type UnsubscribeRequest struct {
// should only be sent if the client previously sent a resources/subscribe request.
type ResourceUpdatedNotification struct {
Notification
- Params struct {
- // The URI of the resource that has been updated. This might be a sub-
- // resource of the one that the client actually subscribed to.
- URI string `json:"uri"`
- } `json:"params"`
+ Params ResourceUpdatedNotificationParams `json:"params"`
+}
+type ResourceUpdatedNotificationParams struct {
+ // The URI of the resource that has been updated. This might be a sub-
+ // resource of the one that the client actually subscribed to.
+ URI string `json:"uri"`
}
// Resource represents a known resource that the server is capable of reading.
@@ -501,6 +650,11 @@ type Resource struct {
MIMEType string `json:"mimeType,omitempty"`
}
+// GetName returns the name of the resource.
+func (r Resource) GetName() string {
+ return r.Name
+}
+
// ResourceTemplate represents a template description for resources available
// on the server.
type ResourceTemplate struct {
@@ -522,6 +676,11 @@ type ResourceTemplate struct {
MIMEType string `json:"mimeType,omitempty"`
}
+// GetName returns the name of the resourceTemplate.
+func (rt ResourceTemplate) GetName() string {
+ return rt.Name
+}
+
// ResourceContents represents the contents of a specific resource or sub-
// resource.
type ResourceContents interface {
@@ -557,12 +716,14 @@ func (BlobResourceContents) isResourceContents() {}
// adjust logging.
type SetLevelRequest struct {
Request
- Params struct {
- // The level of logging that the client wants to receive from the server.
- // The server should send all logs at this level and higher (i.e., more severe) to
- // the client as notifications/logging/message.
- Level LoggingLevel `json:"level"`
- } `json:"params"`
+ Params SetLevelParams `json:"params"`
+}
+
+type SetLevelParams struct {
+ // The level of logging that the client wants to receive from the server.
+ // The server should send all logs at this level and higher (i.e., more severe) to
+ // the client as notifications/logging/message.
+ Level LoggingLevel `json:"level"`
}
// LoggingMessageNotification is a notification of a log message passed from
@@ -570,15 +731,17 @@ type SetLevelRequest struct {
// the server MAY decide which messages to send automatically.
type LoggingMessageNotification struct {
Notification
- Params struct {
- // The severity of this log message.
- Level LoggingLevel `json:"level"`
- // An optional name of the logger issuing this message.
- Logger string `json:"logger,omitempty"`
- // The data to be logged, such as a string message or an object. Any JSON
- // serializable type is allowed here.
- Data interface{} `json:"data"`
- } `json:"params"`
+ Params LoggingMessageNotificationParams `json:"params"`
+}
+
+type LoggingMessageNotificationParams struct {
+ // The severity of this log message.
+ Level LoggingLevel `json:"level"`
+ // An optional name of the logger issuing this message.
+ Logger string `json:"logger,omitempty"`
+ // The data to be logged, such as a string message or an object. Any JSON
+ // serializable type is allowed here.
+ Data any `json:"data"`
}
// LoggingLevel represents the severity of a log message.
@@ -606,16 +769,18 @@ const (
// the request (human in the loop) and decide whether to approve it.
type CreateMessageRequest struct {
Request
- Params struct {
- Messages []SamplingMessage `json:"messages"`
- ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"`
- SystemPrompt string `json:"systemPrompt,omitempty"`
- IncludeContext string `json:"includeContext,omitempty"`
- Temperature float64 `json:"temperature,omitempty"`
- MaxTokens int `json:"maxTokens"`
- StopSequences []string `json:"stopSequences,omitempty"`
- Metadata interface{} `json:"metadata,omitempty"`
- } `json:"params"`
+ CreateMessageParams `json:"params"`
+}
+
+type CreateMessageParams struct {
+ Messages []SamplingMessage `json:"messages"`
+ ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"`
+ SystemPrompt string `json:"systemPrompt,omitempty"`
+ IncludeContext string `json:"includeContext,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ MaxTokens int `json:"maxTokens"`
+ StopSequences []string `json:"stopSequences,omitempty"`
+ Metadata any `json:"metadata,omitempty"`
}
// CreateMessageResult is the client's response to a sampling/create_message
@@ -633,28 +798,30 @@ type CreateMessageResult struct {
// SamplingMessage describes a message issued to or received from an LLM API.
type SamplingMessage struct {
- Role Role `json:"role"`
- Content interface{} `json:"content"` // Can be TextContent or ImageContent
+ Role Role `json:"role"`
+ Content any `json:"content"` // Can be TextContent, ImageContent or AudioContent
+}
+
+type Annotations struct {
+ // Describes who the intended customer of this object or data is.
+ //
+ // It can include multiple entries to indicate content useful for multiple
+ // audiences (e.g., `["user", "assistant"]`).
+ Audience []Role `json:"audience,omitempty"`
+
+ // Describes how important this data is for operating the server.
+ //
+ // A value of 1 means "most important," and indicates that the data is
+ // effectively required, while 0 means "least important," and indicates that
+ // the data is entirely optional.
+ Priority float64 `json:"priority,omitempty"`
}
// Annotated is the base for objects that include optional annotations for the
// client. The client can use annotations to inform how objects are used or
// displayed
type Annotated struct {
- Annotations *struct {
- // Describes who the intended customer of this object or data is.
- //
- // It can include multiple entries to indicate content useful for multiple
- // audiences (e.g., `["user", "assistant"]`).
- Audience []Role `json:"audience,omitempty"`
-
- // Describes how important this data is for operating the server.
- //
- // A value of 1 means "most important," and indicates that the data is
- // effectively required, while 0 means "least important," and indicates that
- // the data is entirely optional.
- Priority float64 `json:"priority,omitempty"`
- } `json:"annotations,omitempty"`
+ Annotations *Annotations `json:"annotations,omitempty"`
}
type Content interface {
@@ -685,6 +852,19 @@ type ImageContent struct {
func (ImageContent) isContent() {}
+// AudioContent represents the contents of audio, embedded into a prompt or tool call result.
+// It must have Type set to "audio".
+type AudioContent struct {
+ Annotated
+ Type string `json:"type"` // Must be "audio"
+ // The base64-encoded audio data.
+ Data string `json:"data"`
+ // The MIME type of the audio. Different providers may support different audio types.
+ MIMEType string `json:"mimeType"`
+}
+
+func (AudioContent) isContent() {}
+
// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result.
//
// It is up to the client how best to render embedded resources for the
@@ -758,15 +938,17 @@ type ModelHint struct {
// CompleteRequest is a request from the client to the server, to ask for completion options.
type CompleteRequest struct {
Request
- Params struct {
- Ref interface{} `json:"ref"` // Can be PromptReference or ResourceReference
- Argument struct {
- // The name of the argument
- Name string `json:"name"`
- // The value of the argument to use for completion matching.
- Value string `json:"value"`
- } `json:"argument"`
- } `json:"params"`
+ Params CompleteParams `json:"params"`
+}
+
+type CompleteParams struct {
+ Ref any `json:"ref"` // Can be PromptReference or ResourceReference
+ Argument struct {
+ // The name of the argument
+ Name string `json:"name"`
+ // The value of the argument to use for completion matching.
+ Value string `json:"value"`
+ } `json:"argument"`
}
// CompleteResult is the server's response to a completion/complete request
@@ -839,22 +1021,24 @@ type RootsListChangedNotification struct {
Notification
}
-/* Client messages */
// ClientRequest represents any request that can be sent from client to server.
-type ClientRequest interface{}
+type ClientRequest any
// ClientNotification represents any notification that can be sent from client to server.
-type ClientNotification interface{}
+type ClientNotification any
// ClientResult represents any result that can be sent from client to server.
-type ClientResult interface{}
+type ClientResult any
-/* Server messages */
// ServerRequest represents any request that can be sent from server to client.
-type ServerRequest interface{}
+type ServerRequest any
// ServerNotification represents any notification that can be sent from server to client.
-type ServerNotification interface{}
+type ServerNotification any
// ServerResult represents any result that can be sent from server to client.
-type ServerResult interface{}
+type ServerResult any
+
+type Named interface {
+ GetName() string
+}
@@ -3,6 +3,8 @@ package mcp
import (
"encoding/json"
"fmt"
+
+ "github.com/spf13/cast"
)
// ClientRequest types
@@ -58,7 +60,7 @@ var _ ServerResult = &ListToolsResult{}
// Helper functions for type assertions
// asType attempts to cast the given interface to the given type
-func asType[T any](content interface{}) (*T, bool) {
+func asType[T any](content any) (*T, bool) {
tc, ok := content.(T)
if !ok {
return nil, false
@@ -67,27 +69,32 @@ func asType[T any](content interface{}) (*T, bool) {
}
// AsTextContent attempts to cast the given interface to TextContent
-func AsTextContent(content interface{}) (*TextContent, bool) {
+func AsTextContent(content any) (*TextContent, bool) {
return asType[TextContent](content)
}
// AsImageContent attempts to cast the given interface to ImageContent
-func AsImageContent(content interface{}) (*ImageContent, bool) {
+func AsImageContent(content any) (*ImageContent, bool) {
return asType[ImageContent](content)
}
+// AsAudioContent attempts to cast the given interface to AudioContent
+func AsAudioContent(content any) (*AudioContent, bool) {
+ return asType[AudioContent](content)
+}
+
// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource
-func AsEmbeddedResource(content interface{}) (*EmbeddedResource, bool) {
+func AsEmbeddedResource(content any) (*EmbeddedResource, bool) {
return asType[EmbeddedResource](content)
}
// AsTextResourceContents attempts to cast the given interface to TextResourceContents
-func AsTextResourceContents(content interface{}) (*TextResourceContents, bool) {
+func AsTextResourceContents(content any) (*TextResourceContents, bool) {
return asType[TextResourceContents](content)
}
// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents
-func AsBlobResourceContents(content interface{}) (*BlobResourceContents, bool) {
+func AsBlobResourceContents(content any) (*BlobResourceContents, bool) {
return asType[BlobResourceContents](content)
}
@@ -107,15 +114,15 @@ func NewJSONRPCError(
id RequestId,
code int,
message string,
- data interface{},
+ data any,
) JSONRPCError {
return JSONRPCError{
JSONRPC: JSONRPC_VERSION,
ID: id,
Error: struct {
- Code int `json:"code"`
- Message string `json:"message"`
- Data interface{} `json:"data,omitempty"`
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data any `json:"data,omitempty"`
}{
Code: code,
Message: message,
@@ -124,11 +131,13 @@ func NewJSONRPCError(
}
}
+// NewProgressNotification
// Helper function for creating a progress notification
func NewProgressNotification(
token ProgressToken,
progress float64,
total *float64,
+ message *string,
) ProgressNotification {
notification := ProgressNotification{
Notification: Notification{
@@ -138,6 +147,7 @@ func NewProgressNotification(
ProgressToken ProgressToken `json:"progressToken"`
Progress float64 `json:"progress"`
Total float64 `json:"total,omitempty"`
+ Message string `json:"message,omitempty"`
}{
ProgressToken: token,
Progress: progress,
@@ -146,14 +156,18 @@ func NewProgressNotification(
if total != nil {
notification.Params.Total = *total
}
+ if message != nil {
+ notification.Params.Message = *message
+ }
return notification
}
+// NewLoggingMessageNotification
// Helper function for creating a logging message notification
func NewLoggingMessageNotification(
level LoggingLevel,
logger string,
- data interface{},
+ data any,
) LoggingMessageNotification {
return LoggingMessageNotification{
Notification: Notification{
@@ -162,7 +176,7 @@ func NewLoggingMessageNotification(
Params: struct {
Level LoggingLevel `json:"level"`
Logger string `json:"logger,omitempty"`
- Data interface{} `json:"data"`
+ Data any `json:"data"`
}{
Level: level,
Logger: logger,
@@ -171,6 +185,7 @@ func NewLoggingMessageNotification(
}
}
+// NewPromptMessage
// Helper function to create a new PromptMessage
func NewPromptMessage(role Role, content Content) PromptMessage {
return PromptMessage{
@@ -179,6 +194,7 @@ func NewPromptMessage(role Role, content Content) PromptMessage {
}
}
+// NewTextContent
// Helper function to create a new TextContent
func NewTextContent(text string) TextContent {
return TextContent{
@@ -187,6 +203,7 @@ func NewTextContent(text string) TextContent {
}
}
+// NewImageContent
// Helper function to create a new ImageContent
func NewImageContent(data, mimeType string) ImageContent {
return ImageContent{
@@ -196,6 +213,15 @@ func NewImageContent(data, mimeType string) ImageContent {
}
}
+// Helper function to create a new AudioContent
+func NewAudioContent(data, mimeType string) AudioContent {
+ return AudioContent{
+ Type: "audio",
+ Data: data,
+ MIMEType: mimeType,
+ }
+}
+
// Helper function to create a new EmbeddedResource
func NewEmbeddedResource(resource ResourceContents) EmbeddedResource {
return EmbeddedResource{
@@ -233,6 +259,23 @@ func NewToolResultImage(text, imageData, mimeType string) *CallToolResult {
}
}
+// NewToolResultAudio creates a new CallToolResult with both text and audio content
+func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult {
+ return &CallToolResult{
+ Content: []Content{
+ TextContent{
+ Type: "text",
+ Text: text,
+ },
+ AudioContent{
+ Type: "audio",
+ Data: imageData,
+ MIMEType: mimeType,
+ },
+ },
+ }
+}
+
// NewToolResultResource creates a new CallToolResult with an embedded resource
func NewToolResultResource(
text string,
@@ -252,6 +295,53 @@ func NewToolResultResource(
}
}
+// NewToolResultError creates a new CallToolResult with an error message.
+// Any errors that originate from the tool SHOULD be reported inside the result object.
+func NewToolResultError(text string) *CallToolResult {
+ return &CallToolResult{
+ Content: []Content{
+ TextContent{
+ Type: "text",
+ Text: text,
+ },
+ },
+ IsError: true,
+ }
+}
+
+// NewToolResultErrorFromErr creates a new CallToolResult with an error message.
+// If an error is provided, its details will be appended to the text message.
+// Any errors that originate from the tool SHOULD be reported inside the result object.
+func NewToolResultErrorFromErr(text string, err error) *CallToolResult {
+ if err != nil {
+ text = fmt.Sprintf("%s: %v", text, err)
+ }
+ return &CallToolResult{
+ Content: []Content{
+ TextContent{
+ Type: "text",
+ Text: text,
+ },
+ },
+ IsError: true,
+ }
+}
+
+// NewToolResultErrorf creates a new CallToolResult with an error message.
+// The error message is formatted using the fmt package.
+// Any errors that originate from the tool SHOULD be reported inside the result object.
+func NewToolResultErrorf(format string, a ...any) *CallToolResult {
+ return &CallToolResult{
+ Content: []Content{
+ TextContent{
+ Type: "text",
+ Text: fmt.Sprintf(format, a...),
+ },
+ },
+ IsError: true,
+ }
+}
+
// NewListResourcesResult creates a new ListResourcesResult
func NewListResourcesResult(
resources []Resource,
@@ -338,6 +428,7 @@ func NewInitializeResult(
}
}
+// FormatNumberResult
// Helper for formatting numbers in tool results
func FormatNumberResult(value float64) *CallToolResult {
return NewToolResultText(fmt.Sprintf("%.2f", value))
@@ -367,9 +458,6 @@ func ParseContent(contentMap map[string]any) (Content, error) {
switch contentType {
case "text":
text := ExtractString(contentMap, "text")
- if text == "" {
- return nil, fmt.Errorf("text is missing")
- }
return NewTextContent(text), nil
case "image":
@@ -380,6 +468,14 @@ func ParseContent(contentMap map[string]any) (Content, error) {
}
return NewImageContent(data, mimeType), nil
+ case "audio":
+ data := ExtractString(contentMap, "data")
+ mimeType := ExtractString(contentMap, "mimeType")
+ if data == "" || mimeType == "" {
+ return nil, fmt.Errorf("audio data or mimeType is missing")
+ }
+ return NewAudioContent(data, mimeType), nil
+
case "resource":
resourceMap := ExtractMap(contentMap, "resource")
if resourceMap == nil {
@@ -398,6 +494,10 @@ func ParseContent(contentMap map[string]any) (Content, error) {
}
func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) {
+ if rawMessage == nil {
+ return nil, fmt.Errorf("response is nil")
+ }
+
var jsonContent map[string]any
if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
@@ -460,6 +560,10 @@ func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error)
}
func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) {
+ if rawMessage == nil {
+ return nil, fmt.Errorf("response is nil")
+ }
+
var jsonContent map[string]any
if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
@@ -538,6 +642,10 @@ func ParseResourceContents(contentMap map[string]any) (ResourceContents, error)
}
func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) {
+ if rawMessage == nil {
+ return nil, fmt.Errorf("response is nil")
+ }
+
var jsonContent map[string]any
if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
@@ -580,3 +688,111 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult,
return &result, nil
}
+
+func ParseArgument(request CallToolRequest, key string, defaultVal any) any {
+ args := request.GetArguments()
+ if _, ok := args[key]; !ok {
+ return defaultVal
+ } else {
+ return args[key]
+ }
+}
+
+// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest.
+// If the key is not found in the Arguments map, the defaultValue is returned.
+// The function uses cast.ToBool for conversion which handles various string representations
+// such as "true", "yes", "1", etc.
+func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToBool(v)
+}
+
+// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest.
+// If the key is not found in the Arguments map, the defaultValue is returned.
+func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToInt64(v)
+}
+
+// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest.
+func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToInt32(v)
+}
+
+// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest.
+func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToInt16(v)
+}
+
+// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest.
+func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToInt8(v)
+}
+
+// ParseInt extracts and converts an int parameter from a CallToolRequest.
+func ParseInt(request CallToolRequest, key string, defaultValue int) int {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToInt(v)
+}
+
+// ParseUInt extracts and converts an uint parameter from a CallToolRequest.
+func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToUint(v)
+}
+
+// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest.
+func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToUint64(v)
+}
+
+// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest.
+func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToUint32(v)
+}
+
+// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest.
+func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToUint16(v)
+}
+
+// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest.
+func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToUint8(v)
+}
+
+// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest.
+func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToFloat32(v)
+}
+
+// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest.
+func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToFloat64(v)
+}
+
+// ParseString extracts and converts a string parameter from a CallToolRequest.
+func ParseString(request CallToolRequest, key string, defaultValue string) string {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToString(v)
+}
+
+// ParseStringMap extracts and converts a string map parameter from a CallToolRequest.
+func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any {
+ v := ParseArgument(request, key, defaultValue)
+ return cast.ToStringMap(v)
+}
+
+// ToBoolPtr returns a pointer to the given boolean value
+func ToBoolPtr(b bool) *bool {
+ return &b
+}
@@ -0,0 +1,34 @@
+package server
+
+import (
+ "errors"
+ "fmt"
+)
+
+var (
+ // Common server errors
+ ErrUnsupported = errors.New("not supported")
+ ErrResourceNotFound = errors.New("resource not found")
+ ErrPromptNotFound = errors.New("prompt not found")
+ ErrToolNotFound = errors.New("tool not found")
+
+ // Session-related errors
+ ErrSessionNotFound = errors.New("session not found")
+ ErrSessionExists = errors.New("session already exists")
+ ErrSessionNotInitialized = errors.New("session not properly initialized")
+ ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools")
+ ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level")
+
+ // Notification-related errors
+ ErrNotificationNotInitialized = errors.New("notification channel not initialized")
+ ErrNotificationChannelBlocked = errors.New("notification channel full or blocked")
+)
+
+// ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration
+type ErrDynamicPathConfig struct {
+ Method string
+}
+
+func (e *ErrDynamicPathConfig) Error() string {
+ return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. Use dynamic path logic in your router.", e.Method)
+}
@@ -0,0 +1,532 @@
+// Code generated by `go generate`. DO NOT EDIT.
+// source: server/internal/gen/hooks.go.tmpl
+package server
+
+import (
+ "context"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered.
+type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession)
+
+// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered.
+type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession)
+
+// BeforeAnyHookFunc is a function that is called after the request is
+// parsed but before the method is called.
+type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any)
+
+// OnSuccessHookFunc is a hook that will be called after the request
+// successfully generates a result, but before the result is sent to the client.
+type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any)
+
+// OnErrorHookFunc is a hook that will be called when an error occurs,
+// either during the request parsing or the method execution.
+//
+// Example usage:
+// ```
+//
+// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
+// // Check for specific error types using errors.Is
+// if errors.Is(err, ErrUnsupported) {
+// // Handle capability not supported errors
+// log.Printf("Capability not supported: %v", err)
+// }
+//
+// // Use errors.As to get specific error types
+// var parseErr = &UnparsableMessageError{}
+// if errors.As(err, &parseErr) {
+// // Access specific methods/fields of the error type
+// log.Printf("Failed to parse message for method %s: %v",
+// parseErr.GetMethod(), parseErr.Unwrap())
+// // Access the raw message that failed to parse
+// rawMsg := parseErr.GetMessage()
+// }
+//
+// // Check for specific resource/prompt/tool errors
+// switch {
+// case errors.Is(err, ErrResourceNotFound):
+// log.Printf("Resource not found: %v", err)
+// case errors.Is(err, ErrPromptNotFound):
+// log.Printf("Prompt not found: %v", err)
+// case errors.Is(err, ErrToolNotFound):
+// log.Printf("Tool not found: %v", err)
+// }
+// })
+type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error)
+
+// OnRequestInitializationFunc is a function that called before handle diff request method
+// Should any errors arise during func execution, the service will promptly return the corresponding error message.
+type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error
+
+type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest)
+type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult)
+
+type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest)
+type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult)
+
+type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest)
+type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult)
+
+type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest)
+type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult)
+
+type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest)
+type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult)
+
+type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest)
+type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult)
+
+type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest)
+type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult)
+
+type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest)
+type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult)
+
+type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest)
+type OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult)
+
+type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest)
+type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult)
+
+type Hooks struct {
+ OnRegisterSession []OnRegisterSessionHookFunc
+ OnUnregisterSession []OnUnregisterSessionHookFunc
+ OnBeforeAny []BeforeAnyHookFunc
+ OnSuccess []OnSuccessHookFunc
+ OnError []OnErrorHookFunc
+ OnRequestInitialization []OnRequestInitializationFunc
+ OnBeforeInitialize []OnBeforeInitializeFunc
+ OnAfterInitialize []OnAfterInitializeFunc
+ OnBeforePing []OnBeforePingFunc
+ OnAfterPing []OnAfterPingFunc
+ OnBeforeSetLevel []OnBeforeSetLevelFunc
+ OnAfterSetLevel []OnAfterSetLevelFunc
+ OnBeforeListResources []OnBeforeListResourcesFunc
+ OnAfterListResources []OnAfterListResourcesFunc
+ OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc
+ OnAfterListResourceTemplates []OnAfterListResourceTemplatesFunc
+ OnBeforeReadResource []OnBeforeReadResourceFunc
+ OnAfterReadResource []OnAfterReadResourceFunc
+ OnBeforeListPrompts []OnBeforeListPromptsFunc
+ OnAfterListPrompts []OnAfterListPromptsFunc
+ OnBeforeGetPrompt []OnBeforeGetPromptFunc
+ OnAfterGetPrompt []OnAfterGetPromptFunc
+ OnBeforeListTools []OnBeforeListToolsFunc
+ OnAfterListTools []OnAfterListToolsFunc
+ OnBeforeCallTool []OnBeforeCallToolFunc
+ OnAfterCallTool []OnAfterCallToolFunc
+}
+
+func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) {
+ c.OnBeforeAny = append(c.OnBeforeAny, hook)
+}
+
+func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) {
+ c.OnSuccess = append(c.OnSuccess, hook)
+}
+
+// AddOnError registers a hook function that will be called when an error occurs.
+// The error parameter contains the actual error object, which can be interrogated
+// using Go's error handling patterns like errors.Is and errors.As.
+//
+// Example:
+// ```
+// // Create a channel to receive errors for testing
+// errChan := make(chan error, 1)
+//
+// // Register hook to capture and inspect errors
+// hooks := &Hooks{}
+//
+// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
+// // For capability-related errors
+// if errors.Is(err, ErrUnsupported) {
+// // Handle capability not supported
+// errChan <- err
+// return
+// }
+//
+// // For parsing errors
+// var parseErr = &UnparsableMessageError{}
+// if errors.As(err, &parseErr) {
+// // Handle unparsable message errors
+// fmt.Printf("Failed to parse %s request: %v\n",
+// parseErr.GetMethod(), parseErr.Unwrap())
+// errChan <- parseErr
+// return
+// }
+//
+// // For resource/prompt/tool not found errors
+// if errors.Is(err, ErrResourceNotFound) ||
+// errors.Is(err, ErrPromptNotFound) ||
+// errors.Is(err, ErrToolNotFound) {
+// // Handle not found errors
+// errChan <- err
+// return
+// }
+//
+// // For other errors
+// errChan <- err
+// })
+//
+// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks))
+// ```
+func (c *Hooks) AddOnError(hook OnErrorHookFunc) {
+ c.OnError = append(c.OnError, hook)
+}
+
+func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) {
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeAny {
+ hook(ctx, id, method, message)
+ }
+}
+
+func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) {
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnSuccess {
+ hook(ctx, id, method, message, result)
+ }
+}
+
+// onError calls all registered error hooks with the error object.
+// The err parameter contains the actual error that occurred, which implements
+// the standard error interface and may be a wrapped error or custom error type.
+//
+// This allows consumer code to use Go's error handling patterns:
+// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors
+// - errors.As(err, &customErr) to extract custom error types
+//
+// Common error types include:
+// - ErrUnsupported: When a capability is not enabled
+// - UnparsableMessageError: When request parsing fails
+// - ErrResourceNotFound: When a resource is not found
+// - ErrPromptNotFound: When a prompt is not found
+// - ErrToolNotFound: When a tool is not found
+func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnError {
+ hook(ctx, id, method, message, err)
+ }
+}
+
+func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) {
+ c.OnRegisterSession = append(c.OnRegisterSession, hook)
+}
+
+func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) {
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnRegisterSession {
+ hook(ctx, session)
+ }
+}
+
+func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) {
+ c.OnUnregisterSession = append(c.OnUnregisterSession, hook)
+}
+
+func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) {
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnUnregisterSession {
+ hook(ctx, session)
+ }
+}
+
+func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) {
+ c.OnRequestInitialization = append(c.OnRequestInitialization, hook)
+}
+
+func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error {
+ if c == nil {
+ return nil
+ }
+ for _, hook := range c.OnRequestInitialization {
+ err := hook(ctx, id, message)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) {
+ c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook)
+}
+
+func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) {
+ c.OnAfterInitialize = append(c.OnAfterInitialize, hook)
+}
+
+func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) {
+ c.beforeAny(ctx, id, mcp.MethodInitialize, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeInitialize {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
+ c.onSuccess(ctx, id, mcp.MethodInitialize, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterInitialize {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) {
+ c.OnBeforePing = append(c.OnBeforePing, hook)
+}
+
+func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) {
+ c.OnAfterPing = append(c.OnAfterPing, hook)
+}
+
+func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) {
+ c.beforeAny(ctx, id, mcp.MethodPing, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforePing {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) {
+ c.onSuccess(ctx, id, mcp.MethodPing, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterPing {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) {
+ c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook)
+}
+
+func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) {
+ c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook)
+}
+
+func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) {
+ c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeSetLevel {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) {
+ c.onSuccess(ctx, id, mcp.MethodSetLogLevel, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterSetLevel {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) {
+ c.OnBeforeListResources = append(c.OnBeforeListResources, hook)
+}
+
+func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) {
+ c.OnAfterListResources = append(c.OnAfterListResources, hook)
+}
+
+func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) {
+ c.beforeAny(ctx, id, mcp.MethodResourcesList, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeListResources {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) {
+ c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterListResources {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) {
+ c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook)
+}
+
+func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) {
+ c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook)
+}
+
+func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) {
+ c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeListResourceTemplates {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) {
+ c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterListResourceTemplates {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) {
+ c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook)
+}
+
+func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) {
+ c.OnAfterReadResource = append(c.OnAfterReadResource, hook)
+}
+
+func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) {
+ c.beforeAny(ctx, id, mcp.MethodResourcesRead, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeReadResource {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) {
+ c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterReadResource {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) {
+ c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook)
+}
+
+func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) {
+ c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook)
+}
+
+func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) {
+ c.beforeAny(ctx, id, mcp.MethodPromptsList, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeListPrompts {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) {
+ c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterListPrompts {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) {
+ c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook)
+}
+
+func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) {
+ c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook)
+}
+
+func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) {
+ c.beforeAny(ctx, id, mcp.MethodPromptsGet, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeGetPrompt {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) {
+ c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterGetPrompt {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) {
+ c.OnBeforeListTools = append(c.OnBeforeListTools, hook)
+}
+
+func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) {
+ c.OnAfterListTools = append(c.OnAfterListTools, hook)
+}
+
+func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) {
+ c.beforeAny(ctx, id, mcp.MethodToolsList, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeListTools {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) {
+ c.onSuccess(ctx, id, mcp.MethodToolsList, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterListTools {
+ hook(ctx, id, message, result)
+ }
+}
+func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) {
+ c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook)
+}
+
+func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) {
+ c.OnAfterCallTool = append(c.OnAfterCallTool, hook)
+}
+
+func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) {
+ c.beforeAny(ctx, id, mcp.MethodToolsCall, message)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnBeforeCallTool {
+ hook(ctx, id, message)
+ }
+}
+
+func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
+ c.onSuccess(ctx, id, mcp.MethodToolsCall, message, result)
+ if c == nil {
+ return
+ }
+ for _, hook := range c.OnAfterCallTool {
+ hook(ctx, id, message, result)
+ }
+}
@@ -0,0 +1,11 @@
+package server
+
+import (
+ "context"
+ "net/http"
+)
+
+// HTTPContextFunc is a function that takes an existing context and the current
+// request and returns a potentially modified context based on the request
+// content. This can be used to inject context values from headers, for example.
+type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context
@@ -0,0 +1,320 @@
+// Code generated by `go generate`. DO NOT EDIT.
+// source: server/internal/gen/request_handler.go.tmpl
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response
+func (s *MCPServer) HandleMessage(
+ ctx context.Context,
+ message json.RawMessage,
+) mcp.JSONRPCMessage {
+ // Add server to context
+ ctx = context.WithValue(ctx, serverKey{}, s)
+ var err *requestError
+
+ var baseMessage struct {
+ JSONRPC string `json:"jsonrpc"`
+ Method mcp.MCPMethod `json:"method"`
+ ID any `json:"id,omitempty"`
+ Result any `json:"result,omitempty"`
+ }
+
+ if err := json.Unmarshal(message, &baseMessage); err != nil {
+ return createErrorResponse(
+ nil,
+ mcp.PARSE_ERROR,
+ "Failed to parse message",
+ )
+ }
+
+ // Check for valid JSONRPC version
+ if baseMessage.JSONRPC != mcp.JSONRPC_VERSION {
+ return createErrorResponse(
+ baseMessage.ID,
+ mcp.INVALID_REQUEST,
+ "Invalid JSON-RPC version",
+ )
+ }
+
+ if baseMessage.ID == nil {
+ var notification mcp.JSONRPCNotification
+ if err := json.Unmarshal(message, ¬ification); err != nil {
+ return createErrorResponse(
+ nil,
+ mcp.PARSE_ERROR,
+ "Failed to parse notification",
+ )
+ }
+ s.handleNotification(ctx, notification)
+ return nil // Return nil for notifications
+ }
+
+ if baseMessage.Result != nil {
+ // this is a response to a request sent by the server (e.g. from a ping
+ // sent due to WithKeepAlive option)
+ return nil
+ }
+
+ handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message)
+ if handleErr != nil {
+ return createErrorResponse(
+ baseMessage.ID,
+ mcp.INVALID_REQUEST,
+ handleErr.Error(),
+ )
+ }
+
+ switch baseMessage.Method {
+ case mcp.MethodInitialize:
+ var request mcp.InitializeRequest
+ var result *mcp.InitializeResult
+ if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeInitialize(ctx, baseMessage.ID, &request)
+ result, err = s.handleInitialize(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodPing:
+ var request mcp.PingRequest
+ var result *mcp.EmptyResult
+ if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforePing(ctx, baseMessage.ID, &request)
+ result, err = s.handlePing(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterPing(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodSetLogLevel:
+ var request mcp.SetLevelRequest
+ var result *mcp.EmptyResult
+ if s.capabilities.logging == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("logging %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request)
+ result, err = s.handleSetLevel(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodResourcesList:
+ var request mcp.ListResourcesRequest
+ var result *mcp.ListResourcesResult
+ if s.capabilities.resources == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("resources %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeListResources(ctx, baseMessage.ID, &request)
+ result, err = s.handleListResources(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterListResources(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodResourcesTemplatesList:
+ var request mcp.ListResourceTemplatesRequest
+ var result *mcp.ListResourceTemplatesResult
+ if s.capabilities.resources == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("resources %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request)
+ result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodResourcesRead:
+ var request mcp.ReadResourceRequest
+ var result *mcp.ReadResourceResult
+ if s.capabilities.resources == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("resources %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeReadResource(ctx, baseMessage.ID, &request)
+ result, err = s.handleReadResource(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodPromptsList:
+ var request mcp.ListPromptsRequest
+ var result *mcp.ListPromptsResult
+ if s.capabilities.prompts == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("prompts %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request)
+ result, err = s.handleListPrompts(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodPromptsGet:
+ var request mcp.GetPromptRequest
+ var result *mcp.GetPromptResult
+ if s.capabilities.prompts == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("prompts %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request)
+ result, err = s.handleGetPrompt(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodToolsList:
+ var request mcp.ListToolsRequest
+ var result *mcp.ListToolsResult
+ if s.capabilities.tools == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("tools %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeListTools(ctx, baseMessage.ID, &request)
+ result, err = s.handleListTools(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterListTools(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ case mcp.MethodToolsCall:
+ var request mcp.CallToolRequest
+ var result *mcp.CallToolResult
+ if s.capabilities.tools == nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.METHOD_NOT_FOUND,
+ err: fmt.Errorf("tools %w", ErrUnsupported),
+ }
+ } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
+ err = &requestError{
+ id: baseMessage.ID,
+ code: mcp.INVALID_REQUEST,
+ err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
+ }
+ } else {
+ s.hooks.beforeCallTool(ctx, baseMessage.ID, &request)
+ result, err = s.handleToolCall(ctx, baseMessage.ID, request)
+ }
+ if err != nil {
+ s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
+ return err.ToJSONRPCError()
+ }
+ s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result)
+ return createResponse(baseMessage.ID, *result)
+ default:
+ return createErrorResponse(
+ baseMessage.ID,
+ mcp.METHOD_NOT_FOUND,
+ fmt.Sprintf("Method %s not found", baseMessage.Method),
+ )
+ }
+}
@@ -0,0 +1,1079 @@
+// Package server provides MCP (Model Context Protocol) server implementations.
+package server
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "slices"
+ "sort"
+ "sync"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// resourceEntry holds both a resource and its handler
+type resourceEntry struct {
+ resource mcp.Resource
+ handler ResourceHandlerFunc
+}
+
+// resourceTemplateEntry holds both a template and its handler
+type resourceTemplateEntry struct {
+ template mcp.ResourceTemplate
+ handler ResourceTemplateHandlerFunc
+}
+
+// ServerOption is a function that configures an MCPServer.
+type ServerOption func(*MCPServer)
+
+// ResourceHandlerFunc is a function that returns resource contents.
+type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error)
+
+// ResourceTemplateHandlerFunc is a function that returns a resource template.
+type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error)
+
+// PromptHandlerFunc handles prompt requests with given arguments.
+type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error)
+
+// ToolHandlerFunc handles tool calls with given arguments.
+type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
+
+// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc.
+type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc
+
+// ToolFilterFunc is a function that filters tools based on context, typically using session information.
+type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool
+
+// ServerTool combines a Tool with its ToolHandlerFunc.
+type ServerTool struct {
+ Tool mcp.Tool
+ Handler ToolHandlerFunc
+}
+
+// ServerPrompt combines a Prompt with its handler function.
+type ServerPrompt struct {
+ Prompt mcp.Prompt
+ Handler PromptHandlerFunc
+}
+
+// ServerResource combines a Resource with its handler function.
+type ServerResource struct {
+ Resource mcp.Resource
+ Handler ResourceHandlerFunc
+}
+
+// serverKey is the context key for storing the server instance
+type serverKey struct{}
+
+// ServerFromContext retrieves the MCPServer instance from a context
+func ServerFromContext(ctx context.Context) *MCPServer {
+ if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok {
+ return srv
+ }
+ return nil
+}
+
+// UnparsableMessageError is attached to the RequestError when json.Unmarshal
+// fails on the request.
+type UnparsableMessageError struct {
+ message json.RawMessage
+ method mcp.MCPMethod
+ err error
+}
+
+func (e *UnparsableMessageError) Error() string {
+ return fmt.Sprintf("unparsable %s request: %s", e.method, e.err)
+}
+
+func (e *UnparsableMessageError) Unwrap() error {
+ return e.err
+}
+
+func (e *UnparsableMessageError) GetMessage() json.RawMessage {
+ return e.message
+}
+
+func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod {
+ return e.method
+}
+
+// RequestError is an error that can be converted to a JSON-RPC error.
+// Implements Unwrap() to allow inspecting the error chain.
+type requestError struct {
+ id any
+ code int
+ err error
+}
+
+func (e *requestError) Error() string {
+ return fmt.Sprintf("request error: %s", e.err)
+}
+
+func (e *requestError) ToJSONRPCError() mcp.JSONRPCError {
+ return mcp.JSONRPCError{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: mcp.NewRequestId(e.id),
+ Error: struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data any `json:"data,omitempty"`
+ }{
+ Code: e.code,
+ Message: e.err.Error(),
+ },
+ }
+}
+
+func (e *requestError) Unwrap() error {
+ return e.err
+}
+
+// NotificationHandlerFunc handles incoming notifications.
+type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification)
+
+// MCPServer implements a Model Context Protocol server that can handle various types of requests
+// including resources, prompts, and tools.
+type MCPServer struct {
+ // Separate mutexes for different resource types
+ resourcesMu sync.RWMutex
+ promptsMu sync.RWMutex
+ toolsMu sync.RWMutex
+ middlewareMu sync.RWMutex
+ notificationHandlersMu sync.RWMutex
+ capabilitiesMu sync.RWMutex
+ toolFiltersMu sync.RWMutex
+
+ name string
+ version string
+ instructions string
+ resources map[string]resourceEntry
+ resourceTemplates map[string]resourceTemplateEntry
+ prompts map[string]mcp.Prompt
+ promptHandlers map[string]PromptHandlerFunc
+ tools map[string]ServerTool
+ toolHandlerMiddlewares []ToolHandlerMiddleware
+ toolFilters []ToolFilterFunc
+ notificationHandlers map[string]NotificationHandlerFunc
+ capabilities serverCapabilities
+ paginationLimit *int
+ sessions sync.Map
+ hooks *Hooks
+}
+
+// WithPaginationLimit sets the pagination limit for the server.
+func WithPaginationLimit(limit int) ServerOption {
+ return func(s *MCPServer) {
+ s.paginationLimit = &limit
+ }
+}
+
+// serverCapabilities defines the supported features of the MCP server
+type serverCapabilities struct {
+ tools *toolCapabilities
+ resources *resourceCapabilities
+ prompts *promptCapabilities
+ logging *bool
+}
+
+// resourceCapabilities defines the supported resource-related features
+type resourceCapabilities struct {
+ subscribe bool
+ listChanged bool
+}
+
+// promptCapabilities defines the supported prompt-related features
+type promptCapabilities struct {
+ listChanged bool
+}
+
+// toolCapabilities defines the supported tool-related features
+type toolCapabilities struct {
+ listChanged bool
+}
+
+// WithResourceCapabilities configures resource-related server capabilities
+func WithResourceCapabilities(subscribe, listChanged bool) ServerOption {
+ return func(s *MCPServer) {
+ // Always create a non-nil capability object
+ s.capabilities.resources = &resourceCapabilities{
+ subscribe: subscribe,
+ listChanged: listChanged,
+ }
+ }
+}
+
+// WithToolHandlerMiddleware allows adding a middleware for the
+// tool handler call chain.
+func WithToolHandlerMiddleware(
+ toolHandlerMiddleware ToolHandlerMiddleware,
+) ServerOption {
+ return func(s *MCPServer) {
+ s.middlewareMu.Lock()
+ s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware)
+ s.middlewareMu.Unlock()
+ }
+}
+
+// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools
+func WithToolFilter(
+ toolFilter ToolFilterFunc,
+) ServerOption {
+ return func(s *MCPServer) {
+ s.toolFiltersMu.Lock()
+ s.toolFilters = append(s.toolFilters, toolFilter)
+ s.toolFiltersMu.Unlock()
+ }
+}
+
+// WithRecovery adds a middleware that recovers from panics in tool handlers.
+func WithRecovery() ServerOption {
+ return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc {
+ return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf(
+ "panic recovered in %s tool handler: %v",
+ request.Params.Name,
+ r,
+ )
+ }
+ }()
+ return next(ctx, request)
+ }
+ })
+}
+
+// WithHooks allows adding hooks that will be called before or after
+// either [all] requests or before / after specific request methods, or else
+// prior to returning an error to the client.
+func WithHooks(hooks *Hooks) ServerOption {
+ return func(s *MCPServer) {
+ s.hooks = hooks
+ }
+}
+
+// WithPromptCapabilities configures prompt-related server capabilities
+func WithPromptCapabilities(listChanged bool) ServerOption {
+ return func(s *MCPServer) {
+ // Always create a non-nil capability object
+ s.capabilities.prompts = &promptCapabilities{
+ listChanged: listChanged,
+ }
+ }
+}
+
+// WithToolCapabilities configures tool-related server capabilities
+func WithToolCapabilities(listChanged bool) ServerOption {
+ return func(s *MCPServer) {
+ // Always create a non-nil capability object
+ s.capabilities.tools = &toolCapabilities{
+ listChanged: listChanged,
+ }
+ }
+}
+
+// WithLogging enables logging capabilities for the server
+func WithLogging() ServerOption {
+ return func(s *MCPServer) {
+ s.capabilities.logging = mcp.ToBoolPtr(true)
+ }
+}
+
+// WithInstructions sets the server instructions for the client returned in the initialize response
+func WithInstructions(instructions string) ServerOption {
+ return func(s *MCPServer) {
+ s.instructions = instructions
+ }
+}
+
+// NewMCPServer creates a new MCP server instance with the given name, version and options
+func NewMCPServer(
+ name, version string,
+ opts ...ServerOption,
+) *MCPServer {
+ s := &MCPServer{
+ resources: make(map[string]resourceEntry),
+ resourceTemplates: make(map[string]resourceTemplateEntry),
+ prompts: make(map[string]mcp.Prompt),
+ promptHandlers: make(map[string]PromptHandlerFunc),
+ tools: make(map[string]ServerTool),
+ name: name,
+ version: version,
+ notificationHandlers: make(map[string]NotificationHandlerFunc),
+ capabilities: serverCapabilities{
+ tools: nil,
+ resources: nil,
+ prompts: nil,
+ logging: nil,
+ },
+ }
+
+ for _, opt := range opts {
+ opt(s)
+ }
+
+ return s
+}
+
+// AddResources registers multiple resources at once
+func (s *MCPServer) AddResources(resources ...ServerResource) {
+ s.implicitlyRegisterResourceCapabilities()
+
+ s.resourcesMu.Lock()
+ for _, entry := range resources {
+ s.resources[entry.Resource.URI] = resourceEntry{
+ resource: entry.Resource,
+ handler: entry.Handler,
+ }
+ }
+ s.resourcesMu.Unlock()
+
+ // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification
+ if s.capabilities.resources.listChanged {
+ // Send notification to all initialized sessions
+ s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
+ }
+}
+
+// AddResource registers a new resource and its handler
+func (s *MCPServer) AddResource(
+ resource mcp.Resource,
+ handler ResourceHandlerFunc,
+) {
+ s.AddResources(ServerResource{Resource: resource, Handler: handler})
+}
+
+// RemoveResource removes a resource from the server
+func (s *MCPServer) RemoveResource(uri string) {
+ s.resourcesMu.Lock()
+ _, exists := s.resources[uri]
+ if exists {
+ delete(s.resources, uri)
+ }
+ s.resourcesMu.Unlock()
+
+ // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource
+ if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
+ s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
+ }
+}
+
+// AddResourceTemplate registers a new resource template and its handler
+func (s *MCPServer) AddResourceTemplate(
+ template mcp.ResourceTemplate,
+ handler ResourceTemplateHandlerFunc,
+) {
+ s.implicitlyRegisterResourceCapabilities()
+
+ s.resourcesMu.Lock()
+ s.resourceTemplates[template.URITemplate.Raw()] = resourceTemplateEntry{
+ template: template,
+ handler: handler,
+ }
+ s.resourcesMu.Unlock()
+
+ // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification
+ if s.capabilities.resources.listChanged {
+ // Send notification to all initialized sessions
+ s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
+ }
+}
+
+// AddPrompts registers multiple prompts at once
+func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) {
+ s.implicitlyRegisterPromptCapabilities()
+
+ s.promptsMu.Lock()
+ for _, entry := range prompts {
+ s.prompts[entry.Prompt.Name] = entry.Prompt
+ s.promptHandlers[entry.Prompt.Name] = entry.Handler
+ }
+ s.promptsMu.Unlock()
+
+ // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification.
+ if s.capabilities.prompts.listChanged {
+ // Send notification to all initialized sessions
+ s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil)
+ }
+}
+
+// AddPrompt registers a new prompt handler with the given name
+func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) {
+ s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler})
+}
+
+// DeletePrompts removes prompts from the server
+func (s *MCPServer) DeletePrompts(names ...string) {
+ s.promptsMu.Lock()
+ var exists bool
+ for _, name := range names {
+ if _, ok := s.prompts[name]; ok {
+ delete(s.prompts, name)
+ delete(s.promptHandlers, name)
+ exists = true
+ }
+ }
+ s.promptsMu.Unlock()
+
+ // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt
+ if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged {
+ // Send notification to all initialized sessions
+ s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil)
+ }
+}
+
+// AddTool registers a new tool and its handler
+func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
+ s.AddTools(ServerTool{Tool: tool, Handler: handler})
+}
+
+// Register tool capabilities due to a tool being added. Default to
+// listChanged: true, but don't change the value if we've already explicitly
+// registered tools.listChanged false.
+func (s *MCPServer) implicitlyRegisterToolCapabilities() {
+ s.implicitlyRegisterCapabilities(
+ func() bool { return s.capabilities.tools != nil },
+ func() { s.capabilities.tools = &toolCapabilities{listChanged: true} },
+ )
+}
+
+func (s *MCPServer) implicitlyRegisterResourceCapabilities() {
+ s.implicitlyRegisterCapabilities(
+ func() bool { return s.capabilities.resources != nil },
+ func() { s.capabilities.resources = &resourceCapabilities{} },
+ )
+}
+
+func (s *MCPServer) implicitlyRegisterPromptCapabilities() {
+ s.implicitlyRegisterCapabilities(
+ func() bool { return s.capabilities.prompts != nil },
+ func() { s.capabilities.prompts = &promptCapabilities{} },
+ )
+}
+
+func (s *MCPServer) implicitlyRegisterCapabilities(check func() bool, register func()) {
+ s.capabilitiesMu.RLock()
+ if check() {
+ s.capabilitiesMu.RUnlock()
+ return
+ }
+ s.capabilitiesMu.RUnlock()
+
+ s.capabilitiesMu.Lock()
+ if !check() {
+ register()
+ }
+ s.capabilitiesMu.Unlock()
+}
+
+// AddTools registers multiple tools at once
+func (s *MCPServer) AddTools(tools ...ServerTool) {
+ s.implicitlyRegisterToolCapabilities()
+
+ s.toolsMu.Lock()
+ for _, entry := range tools {
+ s.tools[entry.Tool.Name] = entry
+ }
+ s.toolsMu.Unlock()
+
+ // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification.
+ if s.capabilities.tools.listChanged {
+ // Send notification to all initialized sessions
+ s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil)
+ }
+}
+
+// SetTools replaces all existing tools with the provided list
+func (s *MCPServer) SetTools(tools ...ServerTool) {
+ s.toolsMu.Lock()
+ s.tools = make(map[string]ServerTool, len(tools))
+ s.toolsMu.Unlock()
+ s.AddTools(tools...)
+}
+
+// DeleteTools removes tools from the server
+func (s *MCPServer) DeleteTools(names ...string) {
+ s.toolsMu.Lock()
+ var exists bool
+ for _, name := range names {
+ if _, ok := s.tools[name]; ok {
+ delete(s.tools, name)
+ exists = true
+ }
+ }
+ s.toolsMu.Unlock()
+
+ // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification.
+ if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
+ // Send notification to all initialized sessions
+ s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil)
+ }
+}
+
+// AddNotificationHandler registers a new handler for incoming notifications
+func (s *MCPServer) AddNotificationHandler(
+ method string,
+ handler NotificationHandlerFunc,
+) {
+ s.notificationHandlersMu.Lock()
+ defer s.notificationHandlersMu.Unlock()
+ s.notificationHandlers[method] = handler
+}
+
+func (s *MCPServer) handleInitialize(
+ ctx context.Context,
+ _ any,
+ request mcp.InitializeRequest,
+) (*mcp.InitializeResult, *requestError) {
+ capabilities := mcp.ServerCapabilities{}
+
+ // Only add resource capabilities if they're configured
+ if s.capabilities.resources != nil {
+ capabilities.Resources = &struct {
+ Subscribe bool `json:"subscribe,omitempty"`
+ ListChanged bool `json:"listChanged,omitempty"`
+ }{
+ Subscribe: s.capabilities.resources.subscribe,
+ ListChanged: s.capabilities.resources.listChanged,
+ }
+ }
+
+ // Only add prompt capabilities if they're configured
+ if s.capabilities.prompts != nil {
+ capabilities.Prompts = &struct {
+ ListChanged bool `json:"listChanged,omitempty"`
+ }{
+ ListChanged: s.capabilities.prompts.listChanged,
+ }
+ }
+
+ // Only add tool capabilities if they're configured
+ if s.capabilities.tools != nil {
+ capabilities.Tools = &struct {
+ ListChanged bool `json:"listChanged,omitempty"`
+ }{
+ ListChanged: s.capabilities.tools.listChanged,
+ }
+ }
+
+ if s.capabilities.logging != nil && *s.capabilities.logging {
+ capabilities.Logging = &struct{}{}
+ }
+
+ result := mcp.InitializeResult{
+ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion),
+ ServerInfo: mcp.Implementation{
+ Name: s.name,
+ Version: s.version,
+ },
+ Capabilities: capabilities,
+ Instructions: s.instructions,
+ }
+
+ if session := ClientSessionFromContext(ctx); session != nil {
+ session.Initialize()
+
+ // Store client info if the session supports it
+ if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok {
+ sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo)
+ }
+ }
+ return &result, nil
+}
+
+func (s *MCPServer) protocolVersion(clientVersion string) string {
+ if slices.Contains(mcp.ValidProtocolVersions, clientVersion) {
+ return clientVersion
+ }
+
+ return mcp.LATEST_PROTOCOL_VERSION
+}
+
+func (s *MCPServer) handlePing(
+ _ context.Context,
+ _ any,
+ _ mcp.PingRequest,
+) (*mcp.EmptyResult, *requestError) {
+ return &mcp.EmptyResult{}, nil
+}
+
+func (s *MCPServer) handleSetLevel(
+ ctx context.Context,
+ id any,
+ request mcp.SetLevelRequest,
+) (*mcp.EmptyResult, *requestError) {
+ clientSession := ClientSessionFromContext(ctx)
+ if clientSession == nil || !clientSession.Initialized() {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INTERNAL_ERROR,
+ err: ErrSessionNotInitialized,
+ }
+ }
+
+ sessionLogging, ok := clientSession.(SessionWithLogging)
+ if !ok {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INTERNAL_ERROR,
+ err: ErrSessionDoesNotSupportLogging,
+ }
+ }
+
+ level := request.Params.Level
+ // Validate logging level
+ switch level {
+ case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice,
+ mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical,
+ mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency:
+ // Valid level
+ default:
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: fmt.Errorf("invalid logging level '%s'", level),
+ }
+ }
+
+ sessionLogging.SetLogLevel(level)
+
+ return &mcp.EmptyResult{}, nil
+}
+
+func listByPagination[T mcp.Named](
+ _ context.Context,
+ s *MCPServer,
+ cursor mcp.Cursor,
+ allElements []T,
+) ([]T, mcp.Cursor, error) {
+ startPos := 0
+ if cursor != "" {
+ c, err := base64.StdEncoding.DecodeString(string(cursor))
+ if err != nil {
+ return nil, "", err
+ }
+ cString := string(c)
+ startPos = sort.Search(len(allElements), func(i int) bool {
+ return allElements[i].GetName() > cString
+ })
+ }
+ endPos := len(allElements)
+ if s.paginationLimit != nil {
+ if len(allElements) > startPos+*s.paginationLimit {
+ endPos = startPos + *s.paginationLimit
+ }
+ }
+ elementsToReturn := allElements[startPos:endPos]
+ // set the next cursor
+ nextCursor := func() mcp.Cursor {
+ if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit {
+ nc := elementsToReturn[len(elementsToReturn)-1].GetName()
+ toString := base64.StdEncoding.EncodeToString([]byte(nc))
+ return mcp.Cursor(toString)
+ }
+ return ""
+ }()
+ return elementsToReturn, nextCursor, nil
+}
+
+func (s *MCPServer) handleListResources(
+ ctx context.Context,
+ id any,
+ request mcp.ListResourcesRequest,
+) (*mcp.ListResourcesResult, *requestError) {
+ s.resourcesMu.RLock()
+ resources := make([]mcp.Resource, 0, len(s.resources))
+ for _, entry := range s.resources {
+ resources = append(resources, entry.resource)
+ }
+ s.resourcesMu.RUnlock()
+
+ // Sort the resources by name
+ sort.Slice(resources, func(i, j int) bool {
+ return resources[i].Name < resources[j].Name
+ })
+ resourcesToReturn, nextCursor, err := listByPagination(
+ ctx,
+ s,
+ request.Params.Cursor,
+ resources,
+ )
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: err,
+ }
+ }
+ result := mcp.ListResourcesResult{
+ Resources: resourcesToReturn,
+ PaginatedResult: mcp.PaginatedResult{
+ NextCursor: nextCursor,
+ },
+ }
+ return &result, nil
+}
+
+func (s *MCPServer) handleListResourceTemplates(
+ ctx context.Context,
+ id any,
+ request mcp.ListResourceTemplatesRequest,
+) (*mcp.ListResourceTemplatesResult, *requestError) {
+ s.resourcesMu.RLock()
+ templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
+ for _, entry := range s.resourceTemplates {
+ templates = append(templates, entry.template)
+ }
+ s.resourcesMu.RUnlock()
+ sort.Slice(templates, func(i, j int) bool {
+ return templates[i].Name < templates[j].Name
+ })
+ templatesToReturn, nextCursor, err := listByPagination(
+ ctx,
+ s,
+ request.Params.Cursor,
+ templates,
+ )
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: err,
+ }
+ }
+ result := mcp.ListResourceTemplatesResult{
+ ResourceTemplates: templatesToReturn,
+ PaginatedResult: mcp.PaginatedResult{
+ NextCursor: nextCursor,
+ },
+ }
+ return &result, nil
+}
+
+func (s *MCPServer) handleReadResource(
+ ctx context.Context,
+ id any,
+ request mcp.ReadResourceRequest,
+) (*mcp.ReadResourceResult, *requestError) {
+ s.resourcesMu.RLock()
+ // First try direct resource handlers
+ if entry, ok := s.resources[request.Params.URI]; ok {
+ handler := entry.handler
+ s.resourcesMu.RUnlock()
+ contents, err := handler(ctx, request)
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INTERNAL_ERROR,
+ err: err,
+ }
+ }
+ return &mcp.ReadResourceResult{Contents: contents}, nil
+ }
+
+ // If no direct handler found, try matching against templates
+ var matchedHandler ResourceTemplateHandlerFunc
+ var matched bool
+ for _, entry := range s.resourceTemplates {
+ template := entry.template
+ if matchesTemplate(request.Params.URI, template.URITemplate) {
+ matchedHandler = entry.handler
+ matched = true
+ matchedVars := template.URITemplate.Match(request.Params.URI)
+ // Convert matched variables to a map
+ request.Params.Arguments = make(map[string]any, len(matchedVars))
+ for name, value := range matchedVars {
+ request.Params.Arguments[name] = value.V
+ }
+ break
+ }
+ }
+ s.resourcesMu.RUnlock()
+
+ if matched {
+ contents, err := matchedHandler(ctx, request)
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INTERNAL_ERROR,
+ err: err,
+ }
+ }
+ return &mcp.ReadResourceResult{Contents: contents}, nil
+ }
+
+ return nil, &requestError{
+ id: id,
+ code: mcp.RESOURCE_NOT_FOUND,
+ err: fmt.Errorf(
+ "handler not found for resource URI '%s': %w",
+ request.Params.URI,
+ ErrResourceNotFound,
+ ),
+ }
+}
+
+// matchesTemplate checks if a URI matches a URI template pattern
+func matchesTemplate(uri string, template *mcp.URITemplate) bool {
+ return template.Regexp().MatchString(uri)
+}
+
+func (s *MCPServer) handleListPrompts(
+ ctx context.Context,
+ id any,
+ request mcp.ListPromptsRequest,
+) (*mcp.ListPromptsResult, *requestError) {
+ s.promptsMu.RLock()
+ prompts := make([]mcp.Prompt, 0, len(s.prompts))
+ for _, prompt := range s.prompts {
+ prompts = append(prompts, prompt)
+ }
+ s.promptsMu.RUnlock()
+
+ // sort prompts by name
+ sort.Slice(prompts, func(i, j int) bool {
+ return prompts[i].Name < prompts[j].Name
+ })
+ promptsToReturn, nextCursor, err := listByPagination(
+ ctx,
+ s,
+ request.Params.Cursor,
+ prompts,
+ )
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: err,
+ }
+ }
+ result := mcp.ListPromptsResult{
+ Prompts: promptsToReturn,
+ PaginatedResult: mcp.PaginatedResult{
+ NextCursor: nextCursor,
+ },
+ }
+ return &result, nil
+}
+
+func (s *MCPServer) handleGetPrompt(
+ ctx context.Context,
+ id any,
+ request mcp.GetPromptRequest,
+) (*mcp.GetPromptResult, *requestError) {
+ s.promptsMu.RLock()
+ handler, ok := s.promptHandlers[request.Params.Name]
+ s.promptsMu.RUnlock()
+
+ if !ok {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound),
+ }
+ }
+
+ result, err := handler(ctx, request)
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INTERNAL_ERROR,
+ err: err,
+ }
+ }
+
+ return result, nil
+}
+
+func (s *MCPServer) handleListTools(
+ ctx context.Context,
+ id any,
+ request mcp.ListToolsRequest,
+) (*mcp.ListToolsResult, *requestError) {
+ // Get the base tools from the server
+ s.toolsMu.RLock()
+ tools := make([]mcp.Tool, 0, len(s.tools))
+
+ // Get all tool names for consistent ordering
+ toolNames := make([]string, 0, len(s.tools))
+ for name := range s.tools {
+ toolNames = append(toolNames, name)
+ }
+
+ // Sort the tool names for consistent ordering
+ sort.Strings(toolNames)
+
+ // Add tools in sorted order
+ for _, name := range toolNames {
+ tools = append(tools, s.tools[name].Tool)
+ }
+ s.toolsMu.RUnlock()
+
+ // Check if there are session-specific tools
+ session := ClientSessionFromContext(ctx)
+ if session != nil {
+ if sessionWithTools, ok := session.(SessionWithTools); ok {
+ if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil {
+ // Override or add session-specific tools
+ // We need to create a map first to merge the tools properly
+ toolMap := make(map[string]mcp.Tool)
+
+ // Add global tools first
+ for _, tool := range tools {
+ toolMap[tool.Name] = tool
+ }
+
+ // Then override with session-specific tools
+ for name, serverTool := range sessionTools {
+ toolMap[name] = serverTool.Tool
+ }
+
+ // Convert back to slice
+ tools = make([]mcp.Tool, 0, len(toolMap))
+ for _, tool := range toolMap {
+ tools = append(tools, tool)
+ }
+
+ // Sort again to maintain consistent ordering
+ sort.Slice(tools, func(i, j int) bool {
+ return tools[i].Name < tools[j].Name
+ })
+ }
+ }
+ }
+
+ // Apply tool filters if any are defined
+ s.toolFiltersMu.RLock()
+ if len(s.toolFilters) > 0 {
+ for _, filter := range s.toolFilters {
+ tools = filter(ctx, tools)
+ }
+ }
+ s.toolFiltersMu.RUnlock()
+
+ // Apply pagination
+ toolsToReturn, nextCursor, err := listByPagination(
+ ctx,
+ s,
+ request.Params.Cursor,
+ tools,
+ )
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: err,
+ }
+ }
+
+ result := mcp.ListToolsResult{
+ Tools: toolsToReturn,
+ PaginatedResult: mcp.PaginatedResult{
+ NextCursor: nextCursor,
+ },
+ }
+ return &result, nil
+}
+
+func (s *MCPServer) handleToolCall(
+ ctx context.Context,
+ id any,
+ request mcp.CallToolRequest,
+) (*mcp.CallToolResult, *requestError) {
+ // First check session-specific tools
+ var tool ServerTool
+ var ok bool
+
+ session := ClientSessionFromContext(ctx)
+ if session != nil {
+ if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk {
+ if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil {
+ var sessionOk bool
+ tool, sessionOk = sessionTools[request.Params.Name]
+ if sessionOk {
+ ok = true
+ }
+ }
+ }
+ }
+
+ // If not found in session tools, check global tools
+ if !ok {
+ s.toolsMu.RLock()
+ tool, ok = s.tools[request.Params.Name]
+ s.toolsMu.RUnlock()
+ }
+
+ if !ok {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INVALID_PARAMS,
+ err: fmt.Errorf("tool '%s' not found: %w", request.Params.Name, ErrToolNotFound),
+ }
+ }
+
+ finalHandler := tool.Handler
+
+ s.middlewareMu.RLock()
+ mw := s.toolHandlerMiddlewares
+ s.middlewareMu.RUnlock()
+
+ // Apply middlewares in reverse order
+ for i := len(mw) - 1; i >= 0; i-- {
+ finalHandler = mw[i](finalHandler)
+ }
+
+ result, err := finalHandler(ctx, request)
+ if err != nil {
+ return nil, &requestError{
+ id: id,
+ code: mcp.INTERNAL_ERROR,
+ err: err,
+ }
+ }
+
+ return result, nil
+}
+
+func (s *MCPServer) handleNotification(
+ ctx context.Context,
+ notification mcp.JSONRPCNotification,
+) mcp.JSONRPCMessage {
+ s.notificationHandlersMu.RLock()
+ handler, ok := s.notificationHandlers[notification.Method]
+ s.notificationHandlersMu.RUnlock()
+
+ if ok {
+ handler(ctx, notification)
+ }
+ return nil
+}
+
+func createResponse(id any, result any) mcp.JSONRPCMessage {
+ return mcp.JSONRPCResponse{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: mcp.NewRequestId(id),
+ Result: result,
+ }
+}
+
+func createErrorResponse(
+ id any,
+ code int,
+ message string,
+) mcp.JSONRPCMessage {
+ return mcp.JSONRPCError{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ ID: mcp.NewRequestId(id),
+ Error: struct {
+ Code int `json:"code"`
+ Message string `json:"message"`
+ Data any `json:"data,omitempty"`
+ }{
+ Code: code,
+ Message: message,
+ },
+ }
+}
@@ -0,0 +1,380 @@
+package server
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// ClientSession represents an active session that can be used by MCPServer to interact with client.
+type ClientSession interface {
+ // Initialize marks session as fully initialized and ready for notifications
+ Initialize()
+ // Initialized returns if session is ready to accept notifications
+ Initialized() bool
+ // NotificationChannel provides a channel suitable for sending notifications to client.
+ NotificationChannel() chan<- mcp.JSONRPCNotification
+ // SessionID is a unique identifier used to track user session.
+ SessionID() string
+}
+
+// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level
+type SessionWithLogging interface {
+ ClientSession
+ // SetLogLevel sets the minimum log level
+ SetLogLevel(level mcp.LoggingLevel)
+ // GetLogLevel retrieves the minimum log level
+ GetLogLevel() mcp.LoggingLevel
+}
+
+// SessionWithTools is an extension of ClientSession that can store session-specific tool data
+type SessionWithTools interface {
+ ClientSession
+ // GetSessionTools returns the tools specific to this session, if any
+ // This method must be thread-safe for concurrent access
+ GetSessionTools() map[string]ServerTool
+ // SetSessionTools sets tools specific to this session
+ // This method must be thread-safe for concurrent access
+ SetSessionTools(tools map[string]ServerTool)
+}
+
+// SessionWithClientInfo is an extension of ClientSession that can store client info
+type SessionWithClientInfo interface {
+ ClientSession
+ // GetClientInfo returns the client information for this session
+ GetClientInfo() mcp.Implementation
+ // SetClientInfo sets the client information for this session
+ SetClientInfo(clientInfo mcp.Implementation)
+}
+
+// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
+type SessionWithStreamableHTTPConfig interface {
+ ClientSession
+ // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server
+ // sends notifications to the client
+ //
+ // The protocol specification:
+ // - If the server response contains any JSON-RPC notifications, it MUST either:
+ // - Return Content-Type: text/event-stream to initiate an SSE stream, OR
+ // - Return Content-Type: application/json for a single JSON object
+ // - The client MUST support both response types.
+ //
+ // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
+ UpgradeToSSEWhenReceiveNotification()
+}
+
+// clientSessionKey is the context key for storing current client notification channel.
+type clientSessionKey struct{}
+
+// ClientSessionFromContext retrieves current client notification context from context.
+func ClientSessionFromContext(ctx context.Context) ClientSession {
+ if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
+ return session
+ }
+ return nil
+}
+
+// WithContext sets the current client session and returns the provided context
+func (s *MCPServer) WithContext(
+ ctx context.Context,
+ session ClientSession,
+) context.Context {
+ return context.WithValue(ctx, clientSessionKey{}, session)
+}
+
+// RegisterSession saves session that should be notified in case if some server attributes changed.
+func (s *MCPServer) RegisterSession(
+ ctx context.Context,
+ session ClientSession,
+) error {
+ sessionID := session.SessionID()
+ if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
+ return ErrSessionExists
+ }
+ s.hooks.RegisterSession(ctx, session)
+ return nil
+}
+
+// UnregisterSession removes from storage session that is shut down.
+func (s *MCPServer) UnregisterSession(
+ ctx context.Context,
+ sessionID string,
+) {
+ sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
+ if !ok {
+ return
+ }
+ if session, ok := sessionValue.(ClientSession); ok {
+ s.hooks.UnregisterSession(ctx, session)
+ }
+}
+
+// SendNotificationToAllClients sends a notification to all the currently active clients.
+func (s *MCPServer) SendNotificationToAllClients(
+ method string,
+ params map[string]any,
+) {
+ notification := mcp.JSONRPCNotification{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ Notification: mcp.Notification{
+ Method: method,
+ Params: mcp.NotificationParams{
+ AdditionalFields: params,
+ },
+ },
+ }
+
+ s.sessions.Range(func(k, v any) bool {
+ if session, ok := v.(ClientSession); ok && session.Initialized() {
+ select {
+ case session.NotificationChannel() <- notification:
+ // Successfully sent notification
+ default:
+ // Channel is blocked, if there's an error hook, use it
+ if s.hooks != nil && len(s.hooks.OnError) > 0 {
+ err := ErrNotificationChannelBlocked
+ // Copy hooks pointer to local variable to avoid race condition
+ hooks := s.hooks
+ go func(sessionID string, hooks *Hooks) {
+ ctx := context.Background()
+ // Use the error hook to report the blocked channel
+ hooks.onError(ctx, nil, "notification", map[string]any{
+ "method": method,
+ "sessionID": sessionID,
+ }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
+ }(session.SessionID(), hooks)
+ }
+ }
+ }
+ return true
+ })
+}
+
+// SendNotificationToClient sends a notification to the current client
+func (s *MCPServer) SendNotificationToClient(
+ ctx context.Context,
+ method string,
+ params map[string]any,
+) error {
+ session := ClientSessionFromContext(ctx)
+ if session == nil || !session.Initialized() {
+ return ErrNotificationNotInitialized
+ }
+
+ // upgrades the client-server communication to SSE stream when the server sends notifications to the client
+ if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
+ sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
+ }
+
+ notification := mcp.JSONRPCNotification{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ Notification: mcp.Notification{
+ Method: method,
+ Params: mcp.NotificationParams{
+ AdditionalFields: params,
+ },
+ },
+ }
+
+ select {
+ case session.NotificationChannel() <- notification:
+ return nil
+ default:
+ // Channel is blocked, if there's an error hook, use it
+ if s.hooks != nil && len(s.hooks.OnError) > 0 {
+ err := ErrNotificationChannelBlocked
+ // Copy hooks pointer to local variable to avoid race condition
+ hooks := s.hooks
+ go func(sessionID string, hooks *Hooks) {
+ // Use the error hook to report the blocked channel
+ hooks.onError(ctx, nil, "notification", map[string]any{
+ "method": method,
+ "sessionID": sessionID,
+ }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
+ }(session.SessionID(), hooks)
+ }
+ return ErrNotificationChannelBlocked
+ }
+}
+
+// SendNotificationToSpecificClient sends a notification to a specific client by session ID
+func (s *MCPServer) SendNotificationToSpecificClient(
+ sessionID string,
+ method string,
+ params map[string]any,
+) error {
+ sessionValue, ok := s.sessions.Load(sessionID)
+ if !ok {
+ return ErrSessionNotFound
+ }
+
+ session, ok := sessionValue.(ClientSession)
+ if !ok || !session.Initialized() {
+ return ErrSessionNotInitialized
+ }
+
+ // upgrades the client-server communication to SSE stream when the server sends notifications to the client
+ if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
+ sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
+ }
+
+ notification := mcp.JSONRPCNotification{
+ JSONRPC: mcp.JSONRPC_VERSION,
+ Notification: mcp.Notification{
+ Method: method,
+ Params: mcp.NotificationParams{
+ AdditionalFields: params,
+ },
+ },
+ }
+
+ select {
+ case session.NotificationChannel() <- notification:
+ return nil
+ default:
+ // Channel is blocked, if there's an error hook, use it
+ if s.hooks != nil && len(s.hooks.OnError) > 0 {
+ err := ErrNotificationChannelBlocked
+ ctx := context.Background()
+ // Copy hooks pointer to local variable to avoid race condition
+ hooks := s.hooks
+ go func(sID string, hooks *Hooks) {
+ // Use the error hook to report the blocked channel
+ hooks.onError(ctx, nil, "notification", map[string]any{
+ "method": method,
+ "sessionID": sID,
+ }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
+ }(sessionID, hooks)
+ }
+ return ErrNotificationChannelBlocked
+ }
+}
+
+// AddSessionTool adds a tool for a specific session
+func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error {
+ return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler})
+}
+
+// AddSessionTools adds tools for a specific session
+func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error {
+ sessionValue, ok := s.sessions.Load(sessionID)
+ if !ok {
+ return ErrSessionNotFound
+ }
+
+ session, ok := sessionValue.(SessionWithTools)
+ if !ok {
+ return ErrSessionDoesNotSupportTools
+ }
+
+ s.implicitlyRegisterToolCapabilities()
+
+ // Get existing tools (this should return a thread-safe copy)
+ sessionTools := session.GetSessionTools()
+
+ // Create a new map to avoid concurrent modification issues
+ newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools))
+
+ // Copy existing tools
+ for k, v := range sessionTools {
+ newSessionTools[k] = v
+ }
+
+ // Add new tools
+ for _, tool := range tools {
+ newSessionTools[tool.Tool.Name] = tool
+ }
+
+ // Set the tools (this should be thread-safe)
+ session.SetSessionTools(newSessionTools)
+
+ // It only makes sense to send tool notifications to initialized sessions --
+ // if we're not initialized yet the client can't possibly have sent their
+ // initial tools/list message.
+ //
+ // For initialized sessions, honor tools.listChanged, which is specifically
+ // about whether notifications will be sent or not.
+ // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
+ if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
+ // Send notification only to this session
+ if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
+ // Log the error but don't fail the operation
+ // The tools were successfully added, but notification failed
+ if s.hooks != nil && len(s.hooks.OnError) > 0 {
+ hooks := s.hooks
+ go func(sID string, hooks *Hooks) {
+ ctx := context.Background()
+ hooks.onError(ctx, nil, "notification", map[string]any{
+ "method": "notifications/tools/list_changed",
+ "sessionID": sID,
+ }, fmt.Errorf("failed to send notification after adding tools: %w", err))
+ }(sessionID, hooks)
+ }
+ }
+ }
+
+ return nil
+}
+
+// DeleteSessionTools removes tools from a specific session
+func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error {
+ sessionValue, ok := s.sessions.Load(sessionID)
+ if !ok {
+ return ErrSessionNotFound
+ }
+
+ session, ok := sessionValue.(SessionWithTools)
+ if !ok {
+ return ErrSessionDoesNotSupportTools
+ }
+
+ // Get existing tools (this should return a thread-safe copy)
+ sessionTools := session.GetSessionTools()
+ if sessionTools == nil {
+ return nil
+ }
+
+ // Create a new map to avoid concurrent modification issues
+ newSessionTools := make(map[string]ServerTool, len(sessionTools))
+
+ // Copy existing tools except those being deleted
+ for k, v := range sessionTools {
+ newSessionTools[k] = v
+ }
+
+ // Remove specified tools
+ for _, name := range names {
+ delete(newSessionTools, name)
+ }
+
+ // Set the tools (this should be thread-safe)
+ session.SetSessionTools(newSessionTools)
+
+ // It only makes sense to send tool notifications to initialized sessions --
+ // if we're not initialized yet the client can't possibly have sent their
+ // initial tools/list message.
+ //
+ // For initialized sessions, honor tools.listChanged, which is specifically
+ // about whether notifications will be sent or not.
+ // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
+ if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
+ // Send notification only to this session
+ if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
+ // Log the error but don't fail the operation
+ // The tools were successfully deleted, but notification failed
+ if s.hooks != nil && len(s.hooks.OnError) > 0 {
+ hooks := s.hooks
+ go func(sID string, hooks *Hooks) {
+ ctx := context.Background()
+ hooks.onError(ctx, nil, "notification", map[string]any{
+ "method": "notifications/tools/list_changed",
+ "sessionID": sID,
+ }, fmt.Errorf("failed to send notification after deleting tools: %w", err))
+ }(sessionID, hooks)
+ }
+ }
+ }
+
+ return nil
+}
@@ -0,0 +1,736 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "path"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/google/uuid"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// sseSession represents an active SSE connection.
+type sseSession struct {
+ done chan struct{}
+ eventQueue chan string // Channel for queuing events
+ sessionID string
+ requestID atomic.Int64
+ notificationChannel chan mcp.JSONRPCNotification
+ initialized atomic.Bool
+ loggingLevel atomic.Value
+ tools sync.Map // stores session-specific tools
+ clientInfo atomic.Value // stores session-specific client info
+}
+
+// SSEContextFunc is a function that takes an existing context and the current
+// request and returns a potentially modified context based on the request
+// content. This can be used to inject context values from headers, for example.
+type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context
+
+// DynamicBasePathFunc allows the user to provide a function to generate the
+// base path for a given request and sessionID. This is useful for cases where
+// the base path is not known at the time of SSE server creation, such as when
+// using a reverse proxy or when the base path is dynamically generated. The
+// function should return the base path (e.g., "/mcp/tenant123").
+type DynamicBasePathFunc func(r *http.Request, sessionID string) string
+
+func (s *sseSession) SessionID() string {
+ return s.sessionID
+}
+
+func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
+ return s.notificationChannel
+}
+
+func (s *sseSession) Initialize() {
+ // set default logging level
+ s.loggingLevel.Store(mcp.LoggingLevelError)
+ s.initialized.Store(true)
+}
+
+func (s *sseSession) Initialized() bool {
+ return s.initialized.Load()
+}
+
+func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) {
+ s.loggingLevel.Store(level)
+}
+
+func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
+ level := s.loggingLevel.Load()
+ if level == nil {
+ return mcp.LoggingLevelError
+ }
+ return level.(mcp.LoggingLevel)
+}
+
+func (s *sseSession) GetSessionTools() map[string]ServerTool {
+ tools := make(map[string]ServerTool)
+ s.tools.Range(func(key, value any) bool {
+ if tool, ok := value.(ServerTool); ok {
+ tools[key.(string)] = tool
+ }
+ return true
+ })
+ return tools
+}
+
+func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
+ // Clear existing tools
+ s.tools.Clear()
+
+ // Set new tools
+ for name, tool := range tools {
+ s.tools.Store(name, tool)
+ }
+}
+
+func (s *sseSession) GetClientInfo() mcp.Implementation {
+ if value := s.clientInfo.Load(); value != nil {
+ if clientInfo, ok := value.(mcp.Implementation); ok {
+ return clientInfo
+ }
+ }
+ return mcp.Implementation{}
+}
+
+func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
+ s.clientInfo.Store(clientInfo)
+}
+
+var (
+ _ ClientSession = (*sseSession)(nil)
+ _ SessionWithTools = (*sseSession)(nil)
+ _ SessionWithLogging = (*sseSession)(nil)
+ _ SessionWithClientInfo = (*sseSession)(nil)
+)
+
+// SSEServer implements a Server-Sent Events (SSE) based MCP server.
+// It provides real-time communication capabilities over HTTP using the SSE protocol.
+type SSEServer struct {
+ server *MCPServer
+ baseURL string
+ basePath string
+ appendQueryToMessageEndpoint bool
+ useFullURLForMessageEndpoint bool
+ messageEndpoint string
+ sseEndpoint string
+ sessions sync.Map
+ srv *http.Server
+ contextFunc SSEContextFunc
+ dynamicBasePathFunc DynamicBasePathFunc
+
+ keepAlive bool
+ keepAliveInterval time.Duration
+
+ mu sync.RWMutex
+}
+
+// SSEOption defines a function type for configuring SSEServer
+type SSEOption func(*SSEServer)
+
+// WithBaseURL sets the base URL for the SSE server
+func WithBaseURL(baseURL string) SSEOption {
+ return func(s *SSEServer) {
+ if baseURL != "" {
+ u, err := url.Parse(baseURL)
+ if err != nil {
+ return
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return
+ }
+ // Check if the host is empty or only contains a port
+ if u.Host == "" || strings.HasPrefix(u.Host, ":") {
+ return
+ }
+ if len(u.Query()) > 0 {
+ return
+ }
+ }
+ s.baseURL = strings.TrimSuffix(baseURL, "/")
+ }
+}
+
+// WithStaticBasePath adds a new option for setting a static base path
+func WithStaticBasePath(basePath string) SSEOption {
+ return func(s *SSEServer) {
+ s.basePath = normalizeURLPath(basePath)
+ }
+}
+
+// WithBasePath adds a new option for setting a static base path.
+//
+// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version.
+//
+//go:deprecated
+func WithBasePath(basePath string) SSEOption {
+ return WithStaticBasePath(basePath)
+}
+
+// WithDynamicBasePath accepts a function for generating the base path. This is
+// useful for cases where the base path is not known at the time of SSE server
+// creation, such as when using a reverse proxy or when the server is mounted
+// at a dynamic path.
+func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption {
+ return func(s *SSEServer) {
+ if fn != nil {
+ s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
+ bp := fn(r, sid)
+ return normalizeURLPath(bp)
+ }
+ }
+ }
+}
+
+// WithMessageEndpoint sets the message endpoint path
+func WithMessageEndpoint(endpoint string) SSEOption {
+ return func(s *SSEServer) {
+ s.messageEndpoint = endpoint
+ }
+}
+
+// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's
+// query parameters to the message endpoint URL that is sent to clients during the SSE connection
+// initialization. This is useful when you need to preserve query parameters from the initial
+// SSE connection request and carry them over to subsequent message requests, maintaining
+// context or authentication details across the communication channel.
+func WithAppendQueryToMessageEndpoint() SSEOption {
+ return func(s *SSEServer) {
+ s.appendQueryToMessageEndpoint = true
+ }
+}
+
+// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL)
+// or just the path portion for the message endpoint. Set to false when clients will concatenate
+// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message".
+func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption {
+ return func(s *SSEServer) {
+ s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint
+ }
+}
+
+// WithSSEEndpoint sets the SSE endpoint path
+func WithSSEEndpoint(endpoint string) SSEOption {
+ return func(s *SSEServer) {
+ s.sseEndpoint = endpoint
+ }
+}
+
+// WithHTTPServer sets the HTTP server instance.
+// NOTE: When providing a custom HTTP server, you must handle routing yourself
+// If routing is not set up, the server will start but won't handle any MCP requests.
+func WithHTTPServer(srv *http.Server) SSEOption {
+ return func(s *SSEServer) {
+ s.srv = srv
+ }
+}
+
+func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
+ return func(s *SSEServer) {
+ s.keepAlive = true
+ s.keepAliveInterval = keepAliveInterval
+ }
+}
+
+func WithKeepAlive(keepAlive bool) SSEOption {
+ return func(s *SSEServer) {
+ s.keepAlive = keepAlive
+ }
+}
+
+// WithSSEContextFunc sets a function that will be called to customise the context
+// to the server using the incoming request.
+func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
+ return func(s *SSEServer) {
+ s.contextFunc = fn
+ }
+}
+
+// NewSSEServer creates a new SSE server instance with the given MCP server and options.
+func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
+ s := &SSEServer{
+ server: server,
+ sseEndpoint: "/sse",
+ messageEndpoint: "/message",
+ useFullURLForMessageEndpoint: true,
+ keepAlive: false,
+ keepAliveInterval: 10 * time.Second,
+ }
+
+ // Apply all options
+ for _, opt := range opts {
+ opt(s)
+ }
+
+ return s
+}
+
+// NewTestServer creates a test server for testing purposes
+func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
+ sseServer := NewSSEServer(server, opts...)
+
+ testServer := httptest.NewServer(sseServer)
+ sseServer.baseURL = testServer.URL
+ return testServer
+}
+
+// Start begins serving SSE connections on the specified address.
+// It sets up HTTP handlers for SSE and message endpoints.
+func (s *SSEServer) Start(addr string) error {
+ s.mu.Lock()
+ if s.srv == nil {
+ s.srv = &http.Server{
+ Addr: addr,
+ Handler: s,
+ }
+ } else {
+ if s.srv.Addr == "" {
+ s.srv.Addr = addr
+ } else if s.srv.Addr != addr {
+ return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr)
+ }
+ }
+ srv := s.srv
+ s.mu.Unlock()
+
+ return srv.ListenAndServe()
+}
+
+// Shutdown gracefully stops the SSE server, closing all active sessions
+// and shutting down the HTTP server.
+func (s *SSEServer) Shutdown(ctx context.Context) error {
+ s.mu.RLock()
+ srv := s.srv
+ s.mu.RUnlock()
+
+ if srv != nil {
+ s.sessions.Range(func(key, value any) bool {
+ if session, ok := value.(*sseSession); ok {
+ close(session.done)
+ }
+ s.sessions.Delete(key)
+ return true
+ })
+
+ return srv.Shutdown(ctx)
+ }
+ return nil
+}
+
+// handleSSE handles incoming SSE connection requests.
+// It sets up appropriate headers and creates a new session for the client.
+func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodGet {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("Access-Control-Allow-Origin", "*")
+
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
+ return
+ }
+
+ sessionID := uuid.New().String()
+ session := &sseSession{
+ done: make(chan struct{}),
+ eventQueue: make(chan string, 100), // Buffer for events
+ sessionID: sessionID,
+ notificationChannel: make(chan mcp.JSONRPCNotification, 100),
+ }
+
+ s.sessions.Store(sessionID, session)
+ defer s.sessions.Delete(sessionID)
+
+ if err := s.server.RegisterSession(r.Context(), session); err != nil {
+ http.Error(
+ w,
+ fmt.Sprintf("Session registration failed: %v", err),
+ http.StatusInternalServerError,
+ )
+ return
+ }
+ defer s.server.UnregisterSession(r.Context(), sessionID)
+
+ // Start notification handler for this session
+ go func() {
+ for {
+ select {
+ case notification := <-session.notificationChannel:
+ eventData, err := json.Marshal(notification)
+ if err == nil {
+ select {
+ case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
+ // Event queued successfully
+ case <-session.done:
+ return
+ }
+ }
+ case <-session.done:
+ return
+ case <-r.Context().Done():
+ return
+ }
+ }
+ }()
+
+ // Start keep alive : ping
+ if s.keepAlive {
+ go func() {
+ ticker := time.NewTicker(s.keepAliveInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ message := mcp.JSONRPCRequest{
+ JSONRPC: "2.0",
+ ID: mcp.NewRequestId(session.requestID.Add(1)),
+ Request: mcp.Request{
+ Method: "ping",
+ },
+ }
+ messageBytes, _ := json.Marshal(message)
+ pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
+ select {
+ case session.eventQueue <- pingMsg:
+ // Message sent successfully
+ case <-session.done:
+ return
+ }
+ case <-session.done:
+ return
+ case <-r.Context().Done():
+ return
+ }
+ }
+ }()
+ }
+
+ // Send the initial endpoint event
+ endpoint := s.GetMessageEndpointForClient(r, sessionID)
+ if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 {
+ endpoint += "&" + r.URL.RawQuery
+ }
+ fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint)
+ flusher.Flush()
+
+ // Main event loop - this runs in the HTTP handler goroutine
+ for {
+ select {
+ case event := <-session.eventQueue:
+ // Write the event to the response
+ fmt.Fprint(w, event)
+ flusher.Flush()
+ case <-r.Context().Done():
+ close(session.done)
+ return
+ case <-session.done:
+ return
+ }
+ }
+}
+
+// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID
+// for the given request. This is the canonical way to compute the message endpoint for a client.
+// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag.
+func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string {
+ basePath := s.basePath
+ if s.dynamicBasePathFunc != nil {
+ basePath = s.dynamicBasePathFunc(r, sessionID)
+ }
+
+ endpointPath := normalizeURLPath(basePath, s.messageEndpoint)
+ if s.useFullURLForMessageEndpoint && s.baseURL != "" {
+ endpointPath = s.baseURL + endpointPath
+ }
+
+ return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID)
+}
+
+// handleMessage processes incoming JSON-RPC messages from clients and sends responses
+// back through the SSE connection and 202 code to HTTP response.
+func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
+ return
+ }
+
+ sessionID := r.URL.Query().Get("sessionId")
+ if sessionID == "" {
+ s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
+ return
+ }
+ sessionI, ok := s.sessions.Load(sessionID)
+ if !ok {
+ s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
+ return
+ }
+ session := sessionI.(*sseSession)
+
+ // Set the client context before handling the message
+ ctx := s.server.WithContext(r.Context(), session)
+ if s.contextFunc != nil {
+ ctx = s.contextFunc(ctx, r)
+ }
+
+ // Parse message as raw JSON
+ var rawMessage json.RawMessage
+ if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
+ s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
+ return
+ }
+
+ // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled.
+ // this is required because the http ctx will be canceled when the client disconnects
+ detachedCtx := context.WithoutCancel(ctx)
+
+ // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
+ w.WriteHeader(http.StatusAccepted)
+
+ // Create a new context for handling the message that will be canceled when the message handling is done
+ messageCtx, cancel := context.WithCancel(detachedCtx)
+
+ go func(ctx context.Context) {
+ defer cancel()
+ // Use the context that will be canceled when session is done
+ // Process message through MCPServer
+ response := s.server.HandleMessage(ctx, rawMessage)
+ // Only send response if there is one (not for notifications)
+ if response != nil {
+ var message string
+ if eventData, err := json.Marshal(response); err != nil {
+ // If there is an error marshalling the response, send a generic error response
+ log.Printf("failed to marshal response: %v", err)
+ message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
+ } else {
+ message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
+ }
+
+ // Queue the event for sending via SSE
+ select {
+ case session.eventQueue <- message:
+ // Event queued successfully
+ case <-session.done:
+ // Session is closed, don't try to queue
+ default:
+ // Queue is full, log this situation
+ log.Printf("Event queue full for session %s", sessionID)
+ }
+ }
+ }(messageCtx)
+}
+
+// writeJSONRPCError writes a JSON-RPC error response with the given error details.
+func (s *SSEServer) writeJSONRPCError(
+ w http.ResponseWriter,
+ id any,
+ code int,
+ message string,
+) {
+ response := createErrorResponse(id, code, message)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ if err := json.NewEncoder(w).Encode(response); err != nil {
+ http.Error(
+ w,
+ fmt.Sprintf("Failed to encode response: %v", err),
+ http.StatusInternalServerError,
+ )
+ return
+ }
+}
+
+// SendEventToSession sends an event to a specific SSE session identified by sessionID.
+// Returns an error if the session is not found or closed.
+func (s *SSEServer) SendEventToSession(
+ sessionID string,
+ event any,
+) error {
+ sessionI, ok := s.sessions.Load(sessionID)
+ if !ok {
+ return fmt.Errorf("session not found: %s", sessionID)
+ }
+ session := sessionI.(*sseSession)
+
+ eventData, err := json.Marshal(event)
+ if err != nil {
+ return err
+ }
+
+ // Queue the event for sending via SSE
+ select {
+ case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
+ return nil
+ case <-session.done:
+ return fmt.Errorf("session closed")
+ default:
+ return fmt.Errorf("event queue full")
+ }
+}
+
+func (s *SSEServer) GetUrlPath(input string) (string, error) {
+ parse, err := url.Parse(input)
+ if err != nil {
+ return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
+ }
+ return parse.Path, nil
+}
+
+func (s *SSEServer) CompleteSseEndpoint() (string, error) {
+ if s.dynamicBasePathFunc != nil {
+ return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
+ }
+
+ path := normalizeURLPath(s.basePath, s.sseEndpoint)
+ return s.baseURL + path, nil
+}
+
+func (s *SSEServer) CompleteSsePath() string {
+ path, err := s.CompleteSseEndpoint()
+ if err != nil {
+ return normalizeURLPath(s.basePath, s.sseEndpoint)
+ }
+ urlPath, err := s.GetUrlPath(path)
+ if err != nil {
+ return normalizeURLPath(s.basePath, s.sseEndpoint)
+ }
+ return urlPath
+}
+
+func (s *SSEServer) CompleteMessageEndpoint() (string, error) {
+ if s.dynamicBasePathFunc != nil {
+ return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
+ }
+ path := normalizeURLPath(s.basePath, s.messageEndpoint)
+ return s.baseURL + path, nil
+}
+
+func (s *SSEServer) CompleteMessagePath() string {
+ path, err := s.CompleteMessageEndpoint()
+ if err != nil {
+ return normalizeURLPath(s.basePath, s.messageEndpoint)
+ }
+ urlPath, err := s.GetUrlPath(path)
+ if err != nil {
+ return normalizeURLPath(s.basePath, s.messageEndpoint)
+ }
+ return urlPath
+}
+
+// SSEHandler returns an http.Handler for the SSE endpoint.
+//
+// This method allows you to mount the SSE handler at any arbitrary path
+// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
+// intended for advanced scenarios where you want to control the routing or
+// support dynamic segments.
+//
+// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
+// you must use the WithDynamicBasePath option to ensure the correct base path
+// is communicated to clients.
+//
+// Example usage:
+//
+// // Advanced/dynamic:
+// sseServer := NewSSEServer(mcpServer,
+// WithDynamicBasePath(func(r *http.Request, sessionID string) string {
+// tenant := r.PathValue("tenant")
+// return "/mcp/" + tenant
+// }),
+// WithBaseURL("http://localhost:8080")
+// )
+// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
+// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
+//
+// For non-dynamic cases, use ServeHTTP method instead.
+func (s *SSEServer) SSEHandler() http.Handler {
+ return http.HandlerFunc(s.handleSSE)
+}
+
+// MessageHandler returns an http.Handler for the message endpoint.
+//
+// This method allows you to mount the message handler at any arbitrary path
+// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
+// intended for advanced scenarios where you want to control the routing or
+// support dynamic segments.
+//
+// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
+// you must use the WithDynamicBasePath option to ensure the correct base path
+// is communicated to clients.
+//
+// Example usage:
+//
+// // Advanced/dynamic:
+// sseServer := NewSSEServer(mcpServer,
+// WithDynamicBasePath(func(r *http.Request, sessionID string) string {
+// tenant := r.PathValue("tenant")
+// return "/mcp/" + tenant
+// }),
+// WithBaseURL("http://localhost:8080")
+// )
+// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
+// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
+//
+// For non-dynamic cases, use ServeHTTP method instead.
+func (s *SSEServer) MessageHandler() http.Handler {
+ return http.HandlerFunc(s.handleMessage)
+}
+
+// ServeHTTP implements the http.Handler interface.
+func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ if s.dynamicBasePathFunc != nil {
+ http.Error(
+ w,
+ (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(),
+ http.StatusInternalServerError,
+ )
+ return
+ }
+ path := r.URL.Path
+ // Use exact path matching rather than Contains
+ ssePath := s.CompleteSsePath()
+ if ssePath != "" && path == ssePath {
+ s.handleSSE(w, r)
+ return
+ }
+ messagePath := s.CompleteMessagePath()
+ if messagePath != "" && path == messagePath {
+ s.handleMessage(w, r)
+ return
+ }
+
+ http.NotFound(w, r)
+}
+
+// normalizeURLPath joins path elements like path.Join but ensures the
+// result always starts with a leading slash and never ends with a slash
+func normalizeURLPath(elem ...string) string {
+ joined := path.Join(elem...)
+
+ // Ensure leading slash
+ if !strings.HasPrefix(joined, "/") {
+ joined = "/" + joined
+ }
+
+ // Remove trailing slash if not just "/"
+ if len(joined) > 1 && strings.HasSuffix(joined, "/") {
+ joined = joined[:len(joined)-1]
+ }
+
+ return joined
+}
@@ -0,0 +1,314 @@
+package server
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "os/signal"
+ "sync/atomic"
+ "syscall"
+
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+// StdioContextFunc is a function that takes an existing context and returns
+// a potentially modified context.
+// This can be used to inject context values from environment variables,
+// for example.
+type StdioContextFunc func(ctx context.Context) context.Context
+
+// StdioServer wraps a MCPServer and handles stdio communication.
+// It provides a simple way to create command-line MCP servers that
+// communicate via standard input/output streams using JSON-RPC messages.
+type StdioServer struct {
+ server *MCPServer
+ errLogger *log.Logger
+ contextFunc StdioContextFunc
+}
+
+// StdioOption defines a function type for configuring StdioServer
+type StdioOption func(*StdioServer)
+
+// WithErrorLogger sets the error logger for the server
+func WithErrorLogger(logger *log.Logger) StdioOption {
+ return func(s *StdioServer) {
+ s.errLogger = logger
+ }
+}
+
+// WithStdioContextFunc sets a function that will be called to customise the context
+// to the server. Note that the stdio server uses the same context for all requests,
+// so this function will only be called once per server instance.
+func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
+ return func(s *StdioServer) {
+ s.contextFunc = fn
+ }
+}
+
+// stdioSession is a static client session, since stdio has only one client.
+type stdioSession struct {
+ notifications chan mcp.JSONRPCNotification
+ initialized atomic.Bool
+ loggingLevel atomic.Value
+ clientInfo atomic.Value // stores session-specific client info
+}
+
+func (s *stdioSession) SessionID() string {
+ return "stdio"
+}
+
+func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
+ return s.notifications
+}
+
+func (s *stdioSession) Initialize() {
+ // set default logging level
+ s.loggingLevel.Store(mcp.LoggingLevelError)
+ s.initialized.Store(true)
+}
+
+func (s *stdioSession) Initialized() bool {
+ return s.initialized.Load()
+}
+
+func (s *stdioSession) GetClientInfo() mcp.Implementation {
+ if value := s.clientInfo.Load(); value != nil {
+ if clientInfo, ok := value.(mcp.Implementation); ok {
+ return clientInfo
+ }
+ }
+ return mcp.Implementation{}
+}
+
+func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
+ s.clientInfo.Store(clientInfo)
+}
+
+func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
+ s.loggingLevel.Store(level)
+}
+
+func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
+ level := s.loggingLevel.Load()
+ if level == nil {
+ return mcp.LoggingLevelError
+ }
+ return level.(mcp.LoggingLevel)
+}
+
+var (
+ _ ClientSession = (*stdioSession)(nil)
+ _ SessionWithLogging = (*stdioSession)(nil)
+ _ SessionWithClientInfo = (*stdioSession)(nil)
+)
+
+var stdioSessionInstance = stdioSession{
+ notifications: make(chan mcp.JSONRPCNotification, 100),
+}
+
+// NewStdioServer creates a new stdio server wrapper around an MCPServer.
+// It initializes the server with a default error logger that discards all output.
+func NewStdioServer(server *MCPServer) *StdioServer {
+ return &StdioServer{
+ server: server,
+ errLogger: log.New(
+ os.Stderr,
+ "",
+ log.LstdFlags,
+ ), // Default to discarding logs
+ }
+}
+
+// SetErrorLogger configures where error messages from the StdioServer are logged.
+// The provided logger will receive all error messages generated during server operation.
+func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
+ s.errLogger = logger
+}
+
+// SetContextFunc sets a function that will be called to customise the context
+// to the server. Note that the stdio server uses the same context for all requests,
+// so this function will only be called once per server instance.
+func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
+ s.contextFunc = fn
+}
+
+// handleNotifications continuously processes notifications from the session's notification channel
+// and writes them to the provided output. It runs until the context is cancelled.
+// Any errors encountered while writing notifications are logged but do not stop the handler.
+func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
+ for {
+ select {
+ case notification := <-stdioSessionInstance.notifications:
+ if err := s.writeResponse(notification, stdout); err != nil {
+ s.errLogger.Printf("Error writing notification: %v", err)
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+// processInputStream continuously reads and processes messages from the input stream.
+// It handles EOF gracefully as a normal termination condition.
+// The function returns when either:
+// - The context is cancelled (returns context.Err())
+// - EOF is encountered (returns nil)
+// - An error occurs while reading or processing messages (returns the error)
+func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
+ for {
+ if err := ctx.Err(); err != nil {
+ return err
+ }
+
+ line, err := s.readNextLine(ctx, reader)
+ if err != nil {
+ if err == io.EOF {
+ return nil
+ }
+ s.errLogger.Printf("Error reading input: %v", err)
+ return err
+ }
+
+ if err := s.processMessage(ctx, line, stdout); err != nil {
+ if err == io.EOF {
+ return nil
+ }
+ s.errLogger.Printf("Error handling message: %v", err)
+ return err
+ }
+ }
+}
+
+// readNextLine reads a single line from the input reader in a context-aware manner.
+// It uses channels to make the read operation cancellable via context.
+// Returns the read line and any error encountered. If the context is cancelled,
+// returns an empty string and the context's error. EOF is returned when the input
+// stream is closed.
+func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
+ type result struct {
+ line string
+ err error
+ }
+
+ resultCh := make(chan result, 1)
+
+ go func() {
+ line, err := reader.ReadString('\n')
+ resultCh <- result{line: line, err: err}
+ }()
+
+ select {
+ case <-ctx.Done():
+ return "", nil
+ case res := <-resultCh:
+ return res.line, res.err
+ }
+}
+
+// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
+// It runs until the context is cancelled or an error occurs.
+// Returns an error if there are issues with reading input or writing output.
+func (s *StdioServer) Listen(
+ ctx context.Context,
+ stdin io.Reader,
+ stdout io.Writer,
+) error {
+ // Set a static client context since stdio only has one client
+ if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
+ return fmt.Errorf("register session: %w", err)
+ }
+ defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
+ ctx = s.server.WithContext(ctx, &stdioSessionInstance)
+
+ // Add in any custom context.
+ if s.contextFunc != nil {
+ ctx = s.contextFunc(ctx)
+ }
+
+ reader := bufio.NewReader(stdin)
+
+ // Start notification handler
+ go s.handleNotifications(ctx, stdout)
+ return s.processInputStream(ctx, reader, stdout)
+}
+
+// processMessage handles a single JSON-RPC message and writes the response.
+// It parses the message, processes it through the wrapped MCPServer, and writes any response.
+// Returns an error if there are issues with message processing or response writing.
+func (s *StdioServer) processMessage(
+ ctx context.Context,
+ line string,
+ writer io.Writer,
+) error {
+ // If line is empty, likely due to ctx cancellation
+ if len(line) == 0 {
+ return nil
+ }
+
+ // Parse the message as raw JSON
+ var rawMessage json.RawMessage
+ if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
+ response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
+ return s.writeResponse(response, writer)
+ }
+
+ // Handle the message using the wrapped server
+ response := s.server.HandleMessage(ctx, rawMessage)
+
+ // Only write response if there is one (not for notifications)
+ if response != nil {
+ if err := s.writeResponse(response, writer); err != nil {
+ return fmt.Errorf("failed to write response: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
+// Returns an error if marshaling or writing fails.
+func (s *StdioServer) writeResponse(
+ response mcp.JSONRPCMessage,
+ writer io.Writer,
+) error {
+ responseBytes, err := json.Marshal(response)
+ if err != nil {
+ return err
+ }
+
+ // Write response followed by newline
+ if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
+// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
+// Returns an error if the server encounters any issues during operation.
+func ServeStdio(server *MCPServer, opts ...StdioOption) error {
+ s := NewStdioServer(server)
+
+ for _, opt := range opts {
+ opt(s)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ // Set up signal handling
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
+
+ go func() {
+ <-sigChan
+ cancel()
+ }()
+
+ return s.Listen(ctx, os.Stdin, os.Stdout)
+}
@@ -0,0 +1,653 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/google/uuid"
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/util"
+)
+
+// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer
+type StreamableHTTPOption func(*StreamableHTTPServer)
+
+// WithEndpointPath sets the endpoint path for the server.
+// The default is "/mcp".
+// It's only works for `Start` method. When used as a http.Handler, it has no effect.
+func WithEndpointPath(endpointPath string) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
+ normalizedPath := "/" + strings.Trim(endpointPath, "/")
+ s.endpointPath = normalizedPath
+ }
+}
+
+// WithStateLess sets the server to stateless mode.
+// If true, the server will manage no session information. Every request will be treated
+// as a new session. No session id returned to the client.
+// The default is false.
+//
+// Notice: This is a convenience method. It's identical to set WithSessionIdManager option
+// to StatelessSessionIdManager.
+func WithStateLess(stateLess bool) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ s.sessionIdManager = &StatelessSessionIdManager{}
+ }
+}
+
+// WithSessionIdManager sets a custom session id generator for the server.
+// By default, the server will use SimpleStatefulSessionIdGenerator, which generates
+// session ids with uuid, and it's insecure.
+// Notice: it will override the WithStateLess option.
+func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ s.sessionIdManager = manager
+ }
+}
+
+// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
+// server will send a heartbeat to the client through the GET connection, to keep
+// the connection alive from being closed by the network infrastructure (e.g.
+// gateways). If the client does not establish a GET connection, it has no
+// effect. The default is not to send heartbeats.
+func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ s.listenHeartbeatInterval = interval
+ }
+}
+
+// WithHTTPContextFunc sets a function that will be called to customise the context
+// to the server using the incoming request.
+// This can be used to inject context values from headers, for example.
+func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ s.contextFunc = fn
+ }
+}
+
+// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
+// NOTE: When providing a custom HTTP server, you must handle routing yourself
+// If routing is not set up, the server will start but won't handle any MCP requests.
+func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ s.httpServer = srv
+ }
+}
+
+// WithLogger sets the logger for the server
+func WithLogger(logger util.Logger) StreamableHTTPOption {
+ return func(s *StreamableHTTPServer) {
+ s.logger = logger
+ }
+}
+
+// StreamableHTTPServer implements a Streamable-http based MCP server.
+// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams.
+// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
+//
+// Usage:
+//
+// server := NewStreamableHTTPServer(mcpServer)
+// server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default
+//
+// or the server itself can be used as a http.Handler, which is convenient to
+// integrate with existing http servers, or advanced usage:
+//
+// handler := NewStreamableHTTPServer(mcpServer)
+// http.Handle("/streamable-http", handler)
+// http.ListenAndServe(":8080", nil)
+//
+// Notice:
+// Except for the GET handlers(listening), the POST handlers(request/notification) will
+// not trigger the session registration. So the methods like `SendNotificationToSpecificClient`
+// or `hooks.onRegisterSession` will not be triggered for POST messages.
+//
+// The current implementation does not support the following features from the specification:
+// - Batching of requests/notifications/responses in arrays.
+// - Stream Resumability
+type StreamableHTTPServer struct {
+ server *MCPServer
+ sessionTools *sessionToolsStore
+ sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64)
+
+ httpServer *http.Server
+ mu sync.RWMutex
+
+ endpointPath string
+ contextFunc HTTPContextFunc
+ sessionIdManager SessionIdManager
+ listenHeartbeatInterval time.Duration
+ logger util.Logger
+}
+
+// NewStreamableHTTPServer creates a new streamable-http server instance
+func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
+ s := &StreamableHTTPServer{
+ server: server,
+ sessionTools: newSessionToolsStore(),
+ endpointPath: "/mcp",
+ sessionIdManager: &InsecureStatefulSessionIdManager{},
+ logger: util.DefaultLogger(),
+ }
+
+ // Apply all options
+ for _, opt := range opts {
+ opt(s)
+ }
+ return s
+}
+
+// ServeHTTP implements the http.Handler interface.
+func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case http.MethodPost:
+ s.handlePost(w, r)
+ case http.MethodGet:
+ s.handleGet(w, r)
+ case http.MethodDelete:
+ s.handleDelete(w, r)
+ default:
+ http.NotFound(w, r)
+ }
+}
+
+// Start begins serving the http server on the specified address and path
+// (endpointPath). like:
+//
+// s.Start(":8080")
+func (s *StreamableHTTPServer) Start(addr string) error {
+ s.mu.Lock()
+ if s.httpServer == nil {
+ mux := http.NewServeMux()
+ mux.Handle(s.endpointPath, s)
+ s.httpServer = &http.Server{
+ Addr: addr,
+ Handler: mux,
+ }
+ } else {
+ if s.httpServer.Addr == "" {
+ s.httpServer.Addr = addr
+ } else if s.httpServer.Addr != addr {
+ return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
+ }
+ }
+ srv := s.httpServer
+ s.mu.Unlock()
+
+ return srv.ListenAndServe()
+}
+
+// Shutdown gracefully stops the server, closing all active sessions
+// and shutting down the HTTP server.
+func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
+
+ // shutdown the server if needed (may use as a http.Handler)
+ s.mu.RLock()
+ srv := s.httpServer
+ s.mu.RUnlock()
+ if srv != nil {
+ return srv.Shutdown(ctx)
+ }
+ return nil
+}
+
+// --- internal methods ---
+
+const (
+ headerKeySessionID = "Mcp-Session-Id"
+)
+
+func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
+ // post request carry request/notification message
+
+ // Check content type
+ contentType := r.Header.Get("Content-Type")
+ if contentType != "application/json" {
+ http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
+ return
+ }
+
+ // Check the request body is valid json, meanwhile, get the request Method
+ rawData, err := io.ReadAll(r.Body)
+ if err != nil {
+ s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
+ return
+ }
+ var baseMessage struct {
+ Method mcp.MCPMethod `json:"method"`
+ }
+ if err := json.Unmarshal(rawData, &baseMessage); err != nil {
+ s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
+ return
+ }
+ isInitializeRequest := baseMessage.Method == mcp.MethodInitialize
+
+ // Prepare the session for the mcp server
+ // The session is ephemeral. Its life is the same as the request. It's only created
+ // for interaction with the mcp server.
+ var sessionID string
+ if isInitializeRequest {
+ // generate a new one for initialize request
+ sessionID = s.sessionIdManager.Generate()
+ } else {
+ // Get session ID from header.
+ // Stateful servers need the client to carry the session ID.
+ sessionID = r.Header.Get(headerKeySessionID)
+ isTerminated, err := s.sessionIdManager.Validate(sessionID)
+ if err != nil {
+ http.Error(w, "Invalid session ID", http.StatusBadRequest)
+ return
+ }
+ if isTerminated {
+ http.Error(w, "Session terminated", http.StatusNotFound)
+ return
+ }
+ }
+
+ session := newStreamableHttpSession(sessionID, s.sessionTools)
+
+ // Set the client context before handling the message
+ ctx := s.server.WithContext(r.Context(), session)
+ if s.contextFunc != nil {
+ ctx = s.contextFunc(ctx, r)
+ }
+
+ // handle potential notifications
+ mu := sync.Mutex{}
+ upgradedHeader := false
+ done := make(chan struct{})
+
+ go func() {
+ for {
+ select {
+ case nt := <-session.notificationChannel:
+ func() {
+ mu.Lock()
+ defer mu.Unlock()
+ // if the done chan is closed, as the request is terminated, just return
+ select {
+ case <-done:
+ return
+ default:
+ }
+ defer func() {
+ flusher, ok := w.(http.Flusher)
+ if ok {
+ flusher.Flush()
+ }
+ }()
+
+ // if there's notifications, upgradedHeader to SSE response
+ if !upgradedHeader {
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.WriteHeader(http.StatusAccepted)
+ upgradedHeader = true
+ }
+ err := writeSSEEvent(w, nt)
+ if err != nil {
+ s.logger.Errorf("Failed to write SSE event: %v", err)
+ return
+ }
+ }()
+ case <-done:
+ return
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ // Process message through MCPServer
+ response := s.server.HandleMessage(ctx, rawData)
+ if response == nil {
+ // For notifications, just send 202 Accepted with no body
+ w.WriteHeader(http.StatusAccepted)
+ return
+ }
+
+ // Write response
+ mu.Lock()
+ defer mu.Unlock()
+ // close the done chan before unlock
+ defer close(done)
+ if ctx.Err() != nil {
+ return
+ }
+ // If client-server communication already upgraded to SSE stream
+ if session.upgradeToSSE.Load() {
+ if !upgradedHeader {
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Connection", "keep-alive")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.WriteHeader(http.StatusAccepted)
+ upgradedHeader = true
+ }
+ if err := writeSSEEvent(w, response); err != nil {
+ s.logger.Errorf("Failed to write final SSE response event: %v", err)
+ }
+ } else {
+ w.Header().Set("Content-Type", "application/json")
+ if isInitializeRequest && sessionID != "" {
+ // send the session ID back to the client
+ w.Header().Set(headerKeySessionID, sessionID)
+ }
+ w.WriteHeader(http.StatusOK)
+ err := json.NewEncoder(w).Encode(response)
+ if err != nil {
+ s.logger.Errorf("Failed to write response: %v", err)
+ }
+ }
+}
+
+func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
+ // get request is for listening to notifications
+ // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
+
+ sessionID := r.Header.Get(headerKeySessionID)
+ // the specification didn't say we should validate the session id
+
+ if sessionID == "" {
+ // It's a stateless server,
+ // but the MCP server requires a unique ID for registering, so we use a random one
+ sessionID = uuid.New().String()
+ }
+
+ session := newStreamableHttpSession(sessionID, s.sessionTools)
+ if err := s.server.RegisterSession(r.Context(), session); err != nil {
+ http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
+ return
+ }
+ defer s.server.UnregisterSession(r.Context(), sessionID)
+
+ // Set the client context before handling the message
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+ w.WriteHeader(http.StatusAccepted)
+
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
+ return
+ }
+ flusher.Flush()
+
+ // Start notification handler for this session
+ done := make(chan struct{})
+ defer close(done)
+ writeChan := make(chan any, 16)
+
+ go func() {
+ for {
+ select {
+ case nt := <-session.notificationChannel:
+ select {
+ case writeChan <- &nt:
+ case <-done:
+ return
+ }
+ case <-done:
+ return
+ }
+ }
+ }()
+
+ if s.listenHeartbeatInterval > 0 {
+ // heartbeat to keep the connection alive
+ go func() {
+ ticker := time.NewTicker(s.listenHeartbeatInterval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-ticker.C:
+ message := mcp.JSONRPCRequest{
+ JSONRPC: "2.0",
+ ID: mcp.NewRequestId(s.nextRequestID(sessionID)),
+ Request: mcp.Request{
+ Method: "ping",
+ },
+ }
+ select {
+ case writeChan <- message:
+ case <-done:
+ return
+ }
+ case <-done:
+ return
+ }
+ }
+ }()
+ }
+
+ // Keep the connection open until the client disconnects
+ //
+ // There's will a Available() check when handler ends, and it maybe race with Flush(),
+ // so we use a separate channel to send the data, inteading of flushing directly in other goroutine.
+ for {
+ select {
+ case data := <-writeChan:
+ if data == nil {
+ continue
+ }
+ if err := writeSSEEvent(w, data); err != nil {
+ s.logger.Errorf("Failed to write SSE event: %v", err)
+ return
+ }
+ flusher.Flush()
+ case <-r.Context().Done():
+ return
+ }
+ }
+}
+
+func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
+ // delete request terminate the session
+ sessionID := r.Header.Get(headerKeySessionID)
+ notAllowed, err := s.sessionIdManager.Terminate(sessionID)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
+ return
+ }
+ if notAllowed {
+ http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ // remove the session relateddata from the sessionToolsStore
+ s.sessionTools.delete(sessionID)
+
+ // remove current session's requstID information
+ s.sessionRequestIDs.Delete(sessionID)
+
+ w.WriteHeader(http.StatusOK)
+}
+
+func writeSSEEvent(w io.Writer, data any) error {
+ jsonData, err := json.Marshal(data)
+ if err != nil {
+ return fmt.Errorf("failed to marshal data: %w", err)
+ }
+ _, err = fmt.Fprintf(w, "event: message\ndata: %s\n\n", jsonData)
+ if err != nil {
+ return fmt.Errorf("failed to write SSE event: %w", err)
+ }
+ return nil
+}
+
+// writeJSONRPCError writes a JSON-RPC error response with the given error details.
+func (s *StreamableHTTPServer) writeJSONRPCError(
+ w http.ResponseWriter,
+ id any,
+ code int,
+ message string,
+) {
+ response := createErrorResponse(id, code, message)
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusBadRequest)
+ err := json.NewEncoder(w).Encode(response)
+ if err != nil {
+ s.logger.Errorf("Failed to write JSONRPCError: %v", err)
+ }
+}
+
+// nextRequestID gets the next incrementing requestID for the current session
+func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
+ actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64))
+ counter := actual.(*atomic.Int64)
+ return counter.Add(1)
+}
+
+// --- session ---
+
+type sessionToolsStore struct {
+ mu sync.RWMutex
+ tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
+}
+
+func newSessionToolsStore() *sessionToolsStore {
+ return &sessionToolsStore{
+ tools: make(map[string]map[string]ServerTool),
+ }
+}
+
+func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+ return s.tools[sessionID]
+}
+
+func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.tools[sessionID] = tools
+}
+
+func (s *sessionToolsStore) delete(sessionID string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.tools, sessionID)
+}
+
+// streamableHttpSession is a session for streamable-http transport
+// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler.
+// When in GET handlers(listening), it's a real session, and will be registered in the MCP server.
+type streamableHttpSession struct {
+ sessionID string
+ notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
+ tools *sessionToolsStore
+ upgradeToSSE atomic.Bool
+}
+
+func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession {
+ return &streamableHttpSession{
+ sessionID: sessionID,
+ notificationChannel: make(chan mcp.JSONRPCNotification, 100),
+ tools: toolStore,
+ }
+}
+
+func (s *streamableHttpSession) SessionID() string {
+ return s.sessionID
+}
+
+func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
+ return s.notificationChannel
+}
+
+func (s *streamableHttpSession) Initialize() {
+ // do nothing
+ // the session is ephemeral, no real initialized action needed
+}
+
+func (s *streamableHttpSession) Initialized() bool {
+ // the session is ephemeral, no real initialized action needed
+ return true
+}
+
+var _ ClientSession = (*streamableHttpSession)(nil)
+
+func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
+ return s.tools.get(s.sessionID)
+}
+
+func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
+ s.tools.set(s.sessionID, tools)
+}
+
+var _ SessionWithTools = (*streamableHttpSession)(nil)
+
+func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
+ s.upgradeToSSE.Store(true)
+}
+
+var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
+
+// --- session id manager ---
+
+type SessionIdManager interface {
+ Generate() string
+ // Validate checks if a session ID is valid and not terminated.
+ // Returns isTerminated=true if the ID is valid but belongs to a terminated session.
+ // Returns err!=nil if the ID format is invalid or lookup failed.
+ Validate(sessionID string) (isTerminated bool, err error)
+ // Terminate marks a session ID as terminated.
+ // Returns isNotAllowed=true if the server policy prevents client termination.
+ // Returns err!=nil if the ID is invalid or termination failed.
+ Terminate(sessionID string) (isNotAllowed bool, err error)
+}
+
+// StatelessSessionIdManager does nothing, which means it has no session management, which is stateless.
+type StatelessSessionIdManager struct{}
+
+func (s *StatelessSessionIdManager) Generate() string {
+ return ""
+}
+func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
+ // In stateless mode, ignore session IDs completely - don't validate or reject them
+ return false, nil
+}
+func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
+ return false, nil
+}
+
+// InsecureStatefulSessionIdManager generate id with uuid
+// It won't validate the id indeed, so it could be fake.
+// For more secure session id, use a more complex generator, like a JWT.
+type InsecureStatefulSessionIdManager struct{}
+
+const idPrefix = "mcp-session-"
+
+func (s *InsecureStatefulSessionIdManager) Generate() string {
+ return idPrefix + uuid.New().String()
+}
+func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
+ // validate the session id is a valid uuid
+ if !strings.HasPrefix(sessionID, idPrefix) {
+ return false, fmt.Errorf("invalid session id: %s", sessionID)
+ }
+ if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
+ return false, fmt.Errorf("invalid session id: %s", sessionID)
+ }
+ return false, nil
+}
+func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
+ return false, nil
+}
+
+// NewTestStreamableHTTPServer creates a test server for testing purposes
+func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
+ sseServer := NewStreamableHTTPServer(server, opts...)
+ testServer := httptest.NewServer(sseServer)
+ return testServer
+}
@@ -0,0 +1,33 @@
+package util
+
+import (
+ "log"
+)
+
+// Logger defines a minimal logging interface
+type Logger interface {
+ Infof(format string, v ...any)
+ Errorf(format string, v ...any)
+}
+
+// --- Standard Library Logger Wrapper ---
+
+// DefaultStdLogger implements Logger using the standard library's log.Logger.
+func DefaultLogger() Logger {
+ return &stdLogger{
+ logger: log.Default(),
+ }
+}
+
+// stdLogger wraps the standard library's log.Logger.
+type stdLogger struct {
+ logger *log.Logger
+}
+
+func (l *stdLogger) Infof(format string, v ...any) {
+ l.logger.Printf("INFO: "+format, v...)
+}
+
+func (l *stdLogger) Errorf(format string, v ...any) {
+ l.logger.Printf("ERROR: "+format, v...)
+}
@@ -1,3 +1,3 @@
{
- ".": "0.1.0-beta.2"
+ ".": "1.8.2"
}
@@ -1,2 +1,4 @@
-configured_endpoints: 80
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai%2Fopenai-5ad6884898c07591750dde560118baf7074a59aecd1f367f930c5e42b04e848a.yml
+configured_endpoints: 97
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai%2Fopenai-a473967d1766dc155994d932fbc4a5bcbd1c140a37c20d0a4065e1bf0640536d.yml
+openapi_spec_hash: 67cdc62b0d6c8b1de29b7dc54b265749
+config_hash: 05c7d4a6f4d5983fe9550457114b47dd
@@ -1,5 +1,357 @@
# Changelog
+## 1.8.2 (2025-06-27)
+
+Full Changelog: [v1.8.1...v1.8.2](https://github.com/openai/openai-go/compare/v1.8.1...v1.8.2)
+
+### Bug Fixes
+
+* don't try to deserialize as json when ResponseBodyInto is []byte ([74ad0f8](https://github.com/openai/openai-go/commit/74ad0f8fab0f956234503a9ba26fbd395944dcf8))
+* **pagination:** check if page data is empty in GetNextPage ([c9becdc](https://github.com/openai/openai-go/commit/c9becdc9908f2a1961160837c6ab8cd9064e7854))
+
+## 1.8.1 (2025-06-26)
+
+Full Changelog: [v1.8.0...v1.8.1](https://github.com/openai/openai-go/compare/v1.8.0...v1.8.1)
+
+### Chores
+
+* **api:** remove unsupported property ([e22316a](https://github.com/openai/openai-go/commit/e22316adcd8f2c5aa672b12453cbd287de0e1878))
+* **docs:** update README to include links to docs on Webhooks ([7bb8f85](https://github.com/openai/openai-go/commit/7bb8f8549fdd98997b1d145cbae98ff0146b4e43))
+
+## 1.8.0 (2025-06-26)
+
+Full Changelog: [v1.7.0...v1.8.0](https://github.com/openai/openai-go/compare/v1.7.0...v1.8.0)
+
+### Features
+
+* **api:** webhook and deep research support ([f6a7e7d](https://github.com/openai/openai-go/commit/f6a7e7dcd8801facc4f8d981f1ca43786c10de1e))
+
+
+### Chores
+
+* **internal:** add tests for breaking change detection ([339522d](https://github.com/openai/openai-go/commit/339522d38cd31b0753a8df37b8924f7e7dfb0b1d))
+
+## 1.7.0 (2025-06-23)
+
+Full Changelog: [v1.6.0...v1.7.0](https://github.com/openai/openai-go/compare/v1.6.0...v1.7.0)
+
+### Features
+
+* **api:** make model and inputs not required to create response ([19f0b76](https://github.com/openai/openai-go/commit/19f0b76378d35b3d81c60c85bf2e64d6bf85b9c2))
+* **api:** update api shapes for usage and code interpreter ([d24d42c](https://github.com/openai/openai-go/commit/d24d42cba60e565627e8ffb1cac63a5085ddb6da))
+* **client:** add escape hatch for null slice & maps ([9c633d6](https://github.com/openai/openai-go/commit/9c633d6f1dbcc0b153f42f831ee7e13d6fe62296))
+
+
+### Chores
+
+* fix documentation of null map ([8f3a134](https://github.com/openai/openai-go/commit/8f3a134e500b1b7791ab855adaef2d7b10d2d1c3))
+
+## 1.6.0 (2025-06-17)
+
+Full Changelog: [v1.5.0...v1.6.0](https://github.com/openai/openai-go/compare/v1.5.0...v1.6.0)
+
+### Features
+
+* **api:** add reusable prompt IDs ([280c698](https://github.com/openai/openai-go/commit/280c698015eba5f6bd47e2fce038eb401f6ef0f2))
+* **api:** manual updates ([740f840](https://github.com/openai/openai-go/commit/740f84006ac283a25f5ad96aaf845a3c8a51c6ac))
+* **client:** add debug log helper ([5715c49](https://github.com/openai/openai-go/commit/5715c491c483f8dab4ea2a900c400384f6810024))
+
+
+### Chores
+
+* **ci:** enable for pull requests ([9ed793a](https://github.com/openai/openai-go/commit/9ed793a51010423db464a7b7bd263d2fd275967f))
+
+## 1.5.0 (2025-06-10)
+
+Full Changelog: [v1.4.0...v1.5.0](https://github.com/openai/openai-go/compare/v1.4.0...v1.5.0)
+
+### Features
+
+* **api:** Add o3-pro model IDs ([3bbd0b8](https://github.com/openai/openai-go/commit/3bbd0b8f09030a6c571900d444742c4fc2a3c211))
+
+## 1.4.0 (2025-06-09)
+
+Full Changelog: [v1.3.0...v1.4.0](https://github.com/openai/openai-go/compare/v1.3.0...v1.4.0)
+
+### Features
+
+* **client:** allow overriding unions ([27c6299](https://github.com/openai/openai-go/commit/27c6299cb4ac275c6542b5691d81b795e65eeff6))
+
+
+### Bug Fixes
+
+* **client:** cast to raw message when converting to params ([a3282b0](https://github.com/openai/openai-go/commit/a3282b01a8d9a2c0cd04f24b298bf2ffcd160ebd))
+
+## 1.3.0 (2025-06-03)
+
+Full Changelog: [v1.2.1...v1.3.0](https://github.com/openai/openai-go/compare/v1.2.1...v1.3.0)
+
+### Features
+
+* **api:** add new realtime and audio models, realtime session options ([8b8f62b](https://github.com/openai/openai-go/commit/8b8f62b8e185f3fe4aaa99e892df5d35638931a1))
+
+## 1.2.1 (2025-06-02)
+
+Full Changelog: [v1.2.0...v1.2.1](https://github.com/openai/openai-go/compare/v1.2.0...v1.2.1)
+
+### Bug Fixes
+
+* **api:** Fix evals and code interpreter interfaces ([7e244c7](https://github.com/openai/openai-go/commit/7e244c73caad6b4768cced9a798452f03b1165c8))
+* fix error ([a200fca](https://github.com/openai/openai-go/commit/a200fca92c3fa413cf724f424077d1537fa2ca3e))
+
+
+### Chores
+
+* make go mod tidy continue on error ([48f41c2](https://github.com/openai/openai-go/commit/48f41c2993bf6181018da859ae759951261f9ee2))
+
+## 1.2.0 (2025-05-29)
+
+Full Changelog: [v1.1.0...v1.2.0](https://github.com/openai/openai-go/compare/v1.1.0...v1.2.0)
+
+### Features
+
+* **api:** Config update for pakrym-stream-param ([84d59d5](https://github.com/openai/openai-go/commit/84d59d5cbc7521ddcc04435317903fd4ec3d17f6))
+
+
+### Bug Fixes
+
+* **client:** return binary content from `get /containers/{container_id}/files/{file_id}/content` ([f8c8de1](https://github.com/openai/openai-go/commit/f8c8de18b720b224267d54da53d7d919ed0fdff3))
+
+
+### Chores
+
+* deprecate Assistants API ([027470e](https://github.com/openai/openai-go/commit/027470e066ea6bbca1aeeb4fb9a8a3430babb84c))
+* **internal:** fix release workflows ([fd46533](https://github.com/openai/openai-go/commit/fd4653316312755ccab7435fca9fb0a2d8bf8fbb))
+
+## 1.1.0 (2025-05-22)
+
+Full Changelog: [v1.0.0...v1.1.0](https://github.com/openai/openai-go/compare/v1.0.0...v1.1.0)
+
+### Features
+
+* **api:** add container endpoint ([2bd777d](https://github.com/openai/openai-go/commit/2bd777d6813b5dfcd3a2d339047a944c478dcd64))
+* **api:** new API tools ([e7e2123](https://github.com/openai/openai-go/commit/e7e2123de7cafef515e07adde6edd45a7035b610))
+* **api:** new streaming helpers for background responses ([422a0db](https://github.com/openai/openai-go/commit/422a0db3c674135e23dd200f5d8d785bd0be33e6))
+
+
+### Chores
+
+* **docs:** grammar improvements ([f4b23dd](https://github.com/openai/openai-go/commit/f4b23dd31facfc8839310854521b48060ef76be2))
+* improve devcontainer setup ([dfdaeec](https://github.com/openai/openai-go/commit/dfdaeec2d6dd5cd679514d60c49b68c5df9e1b1e))
+
+## 1.0.0 (2025-05-19)
+
+Full Changelog: [v0.1.0-beta.11...v1.0.0](https://github.com/openai/openai-go/compare/v0.1.0-beta.11...v1.0.0)
+
+### โ BREAKING CHANGES
+
+* **client:** rename file array param variant
+* **api:** improve naming and remove assistants
+* **accumulator:** update casing ([#401](https://github.com/openai/openai-go/issues/401))
+
+### Features
+
+* **api:** improve naming and remove assistants ([4c623b8](https://github.com/openai/openai-go/commit/4c623b88a9025db1961cc57985eb7374342f43e7))
+
+
+### Bug Fixes
+
+* **accumulator:** update casing ([#401](https://github.com/openai/openai-go/issues/401)) ([d59453c](https://github.com/openai/openai-go/commit/d59453c95b89fdd0b51305778dec0a39ce3a9d2a))
+* **client:** correctly set stream key for multipart ([0ec68f0](https://github.com/openai/openai-go/commit/0ec68f0d779e7726931b1115eca9ae81eab59ba8))
+* **client:** don't panic on marshal with extra null field ([9c15332](https://github.com/openai/openai-go/commit/9c153320272d212beaa516d4c70d54ae8053a958))
+* **client:** increase max stream buffer size ([9456455](https://github.com/openai/openai-go/commit/945645559c5d68d9e28cf445d9c3b83e5fc6bd35))
+* **client:** rename file array param variant ([4cfcf86](https://github.com/openai/openai-go/commit/4cfcf869280e7531fbbc8c00db0dd9271d07c423))
+* **client:** use scanner for streaming ([aa58806](https://github.com/openai/openai-go/commit/aa58806bffc3aed68425c480414ddbb4dac3fa78))
+
+
+### Chores
+
+* **docs:** typo fix ([#400](https://github.com/openai/openai-go/issues/400)) ([bececf2](https://github.com/openai/openai-go/commit/bececf24cd0324b7c991b7d7f1d3eff6bf71f996))
+* **examples:** migrate enum ([#447](https://github.com/openai/openai-go/issues/447)) ([814dd8b](https://github.com/openai/openai-go/commit/814dd8b6cfe4eeb535dc8ecd161a409ea2eb6698))
+* **examples:** migrate to latest version ([#444](https://github.com/openai/openai-go/issues/444)) ([1c8754f](https://github.com/openai/openai-go/commit/1c8754ff905ed023f6381c8493910d63039407de))
+* **examples:** remove beta assisstants examples ([#445](https://github.com/openai/openai-go/issues/445)) ([5891583](https://github.com/openai/openai-go/commit/589158372be9c0517b5508f9ccd872fdb1fe480b))
+* **example:** update fine-tuning ([#450](https://github.com/openai/openai-go/issues/450)) ([421e3c5](https://github.com/openai/openai-go/commit/421e3c5065ace2d5ddd3d13a036477fff9123e5f))
+
+## 0.1.0-beta.11 (2025-05-16)
+
+Full Changelog: [v0.1.0-beta.10...v0.1.0-beta.11](https://github.com/openai/openai-go/compare/v0.1.0-beta.10...v0.1.0-beta.11)
+
+### โ BREAKING CHANGES
+
+* **client:** clearer array variant names
+* **client:** rename resp package
+* **client:** improve core function names
+* **client:** improve union variant names
+* **client:** improve param subunions & deduplicate types
+
+### Features
+
+* **api:** add image sizes, reasoning encryption ([0852fb3](https://github.com/openai/openai-go/commit/0852fb3101dc940761f9e4f32875bfcf3669eada))
+* **api:** add o3 and o4-mini model IDs ([3fabca6](https://github.com/openai/openai-go/commit/3fabca6b5c610edfb7bcd0cab5334a06444df0b0))
+* **api:** Add reinforcement fine-tuning api support ([831a124](https://github.com/openai/openai-go/commit/831a12451cfce907b5ae4d294b9c2ac95f40d97a))
+* **api:** adding gpt-4.1 family of model IDs ([1ef19d4](https://github.com/openai/openai-go/commit/1ef19d4cc94992dc435d7d5f28b30c9b1d255cd4))
+* **api:** adding new image model support ([bf17880](https://github.com/openai/openai-go/commit/bf17880e182549c5c0fc34ec05df3184f223bc00))
+* **api:** manual updates ([11f5716](https://github.com/openai/openai-go/commit/11f5716afa86aa100f80f3fa127e1d49203e5e21))
+* **api:** responses x eval api ([183aaf7](https://github.com/openai/openai-go/commit/183aaf700f1d7ffad4ac847627d9ace65379c459))
+* **api:** Updating Assistants and Evals API schemas ([47ca619](https://github.com/openai/openai-go/commit/47ca619fa1b439cf3a68c98e48e9bf1942f0568b))
+* **client:** add dynamic streaming buffer to handle large lines ([8e6aad6](https://github.com/openai/openai-go/commit/8e6aad6d54fc73f1fcc174e1f06c9b3cf00c2689))
+* **client:** add helper method to generate constant structs ([ff82809](https://github.com/openai/openai-go/commit/ff828094b561fc11184fed83f04424b6f68f7781))
+* **client:** add support for endpoint-specific base URLs in python ([072dce4](https://github.com/openai/openai-go/commit/072dce46486d373fa0f0de5415f5270b01c2d972))
+* **client:** add support for reading base URL from environment variable ([0d37268](https://github.com/openai/openai-go/commit/0d372687d673990290bad583f1906a2b121960b2))
+* **client:** clearer array variant names ([a5d8b5d](https://github.com/openai/openai-go/commit/a5d8b5d6b161e3083184586840b2cbe0606d8de1))
+* **client:** experimental support for unmarshalling into param structs ([5234875](https://github.com/openai/openai-go/commit/523487582e15a47e2f409f183568551258f4b8fe))
+* **client:** improve param subunions & deduplicate types ([8a78f37](https://github.com/openai/openai-go/commit/8a78f37c25abf10498d16d210de3078f491ff23e))
+* **client:** rename resp package ([4433516](https://github.com/openai/openai-go/commit/443351625ee290937a25425719b099ce785bd21b))
+* **client:** support more time formats ([ec171b2](https://github.com/openai/openai-go/commit/ec171b2405c46f9cf04560760da001f7133d2fec))
+* fix lint ([9c50a1e](https://github.com/openai/openai-go/commit/9c50a1eb9f93b578cb78085616f6bfab69f21dbc))
+
+
+### Bug Fixes
+
+* **client:** clean up reader resources ([710b92e](https://github.com/openai/openai-go/commit/710b92eaa7e94c03aeeca7479668677b32acb154))
+* **client:** correctly update body in WithJSONSet ([f2d7118](https://github.com/openai/openai-go/commit/f2d7118295dd3073aa449426801d02e6f60bdaa3))
+* **client:** improve core function names ([9f312a9](https://github.com/openai/openai-go/commit/9f312a9b14f5424d44d5834f1b82f3d3fcd57db2))
+* **client:** improve union variant names ([a2c3de9](https://github.com/openai/openai-go/commit/a2c3de9e6c9f6e406b953f6de2eb78d1e72ec1b5))
+* **client:** include path for type names in example code ([69561c5](https://github.com/openai/openai-go/commit/69561c549e18bd16a3641d62769479b125a4e955))
+* **client:** resolve issue with optional multipart files ([910d173](https://github.com/openai/openai-go/commit/910d1730e97a03898e5dee7c889844a2ccec3e56))
+* **client:** time format encoding fix ([ca17553](https://github.com/openai/openai-go/commit/ca175533ac8a17d36be1f531bbaa89c770da3f58))
+* **client:** unmarshal responses properly ([fc9fec3](https://github.com/openai/openai-go/commit/fc9fec3c466ba9f633c3f7a4eebb5ebd3b85e8ac))
+* handle empty bodies in WithJSONSet ([8372464](https://github.com/openai/openai-go/commit/83724640c6c00dcef1547dcabace309f17d14afc))
+* **pagination:** handle errors when applying options ([eebf84b](https://github.com/openai/openai-go/commit/eebf84bf19f0eb6d9fa21e64bb83b0258e8cb42c))
+
+
+### Chores
+
+* **ci:** add timeout thresholds for CI jobs ([26b0dd7](https://github.com/openai/openai-go/commit/26b0dd760c142ca3aa287e8441bbe44cc8b3be0b))
+* **ci:** only use depot for staging repos ([7682154](https://github.com/openai/openai-go/commit/7682154fdbcbe2a2ffdb2df590647a1712d52275))
+* **ci:** run on more branches and use depot runners ([d7badbc](https://github.com/openai/openai-go/commit/d7badbc0d17bcf3cffec332f65cb68e531cb3176))
+* **docs:** document pre-request options ([4befa5a](https://github.com/openai/openai-go/commit/4befa5a48ca61372715f36c45e72eb159d95bf2d))
+* **docs:** update respjson package name ([9a00229](https://github.com/openai/openai-go/commit/9a002299a91e1145f053c51b1a4de10298fd2f43))
+* **readme:** improve formatting ([a847e8d](https://github.com/openai/openai-go/commit/a847e8df45f725f9652fcea53ce57d3b9046efc7))
+* **utils:** add internal resp to param utility ([239c4e2](https://github.com/openai/openai-go/commit/239c4e2cb32c7af71ab14668ccc2f52ea59653f9))
+
+
+### Documentation
+
+* update documentation links to be more uniform ([f5f0bb0](https://github.com/openai/openai-go/commit/f5f0bb05ee705d84119806f8e703bf2e0becb1fa))
+
+## 0.1.0-beta.10 (2025-04-14)
+
+Full Changelog: [v0.1.0-beta.9...v0.1.0-beta.10](https://github.com/openai/openai-go/compare/v0.1.0-beta.9...v0.1.0-beta.10)
+
+### Chores
+
+* **internal:** expand CI branch coverage ([#369](https://github.com/openai/openai-go/issues/369)) ([258dda8](https://github.com/openai/openai-go/commit/258dda8007a69b9c2720b225ee6d27474d676a93))
+* **internal:** reduce CI branch coverage ([a2f7c03](https://github.com/openai/openai-go/commit/a2f7c03eb984d98f29f908df103ea1743f2e3d9a))
+
+## 0.1.0-beta.9 (2025-04-09)
+
+Full Changelog: [v0.1.0-beta.8...v0.1.0-beta.9](https://github.com/openai/openai-go/compare/v0.1.0-beta.8...v0.1.0-beta.9)
+
+### Chores
+
+* workaround build errors ([#366](https://github.com/openai/openai-go/issues/366)) ([adeb003](https://github.com/openai/openai-go/commit/adeb003cab8efbfbf4424e03e96a0f5e728551cb))
+
+## 0.1.0-beta.8 (2025-04-09)
+
+Full Changelog: [v0.1.0-beta.7...v0.1.0-beta.8](https://github.com/openai/openai-go/compare/v0.1.0-beta.7...v0.1.0-beta.8)
+
+### Features
+
+* **api:** Add evalapi to sdk ([#360](https://github.com/openai/openai-go/issues/360)) ([88977d1](https://github.com/openai/openai-go/commit/88977d1868dbbe0060c56ba5dac8eb19773e4938))
+* **api:** manual updates ([#363](https://github.com/openai/openai-go/issues/363)) ([5d068e0](https://github.com/openai/openai-go/commit/5d068e0053172db7f5b75038aa215eee074eeeed))
+* **client:** add escape hatch to omit required param fields ([#354](https://github.com/openai/openai-go/issues/354)) ([9690d6b](https://github.com/openai/openai-go/commit/9690d6b49f8b00329afc038ec15116750853e620))
+* **client:** support custom http clients ([#357](https://github.com/openai/openai-go/issues/357)) ([b5a624f](https://github.com/openai/openai-go/commit/b5a624f658cad774094427b36b05e446b41e8c52))
+
+
+### Chores
+
+* **docs:** readme improvements ([#356](https://github.com/openai/openai-go/issues/356)) ([b2f8539](https://github.com/openai/openai-go/commit/b2f8539d6316e3443aa733be2c95926696119c13))
+* **internal:** fix examples ([#361](https://github.com/openai/openai-go/issues/361)) ([de398b4](https://github.com/openai/openai-go/commit/de398b453d398299eb80c15f8fdb2bcbef5eeed6))
+* **internal:** skip broken test ([#362](https://github.com/openai/openai-go/issues/362)) ([cccead9](https://github.com/openai/openai-go/commit/cccead9ba916142ac8fbe6e8926d706511e32ae3))
+* **tests:** improve enum examples ([#359](https://github.com/openai/openai-go/issues/359)) ([e0b9739](https://github.com/openai/openai-go/commit/e0b9739920114d6e991d3947b67fdf62cfaa09c7))
+
+## 0.1.0-beta.7 (2025-04-07)
+
+Full Changelog: [v0.1.0-beta.6...v0.1.0-beta.7](https://github.com/openai/openai-go/compare/v0.1.0-beta.6...v0.1.0-beta.7)
+
+### Features
+
+* **client:** make response union's AsAny method type safe ([#352](https://github.com/openai/openai-go/issues/352)) ([1252f56](https://github.com/openai/openai-go/commit/1252f56c917e57d6d2b031501b2ff5f89f87cf87))
+
+
+### Chores
+
+* **docs:** doc improvements ([#350](https://github.com/openai/openai-go/issues/350)) ([80debc8](https://github.com/openai/openai-go/commit/80debc824eaacb4b07c8f3e8b1d0488d860d5be5))
+
+## 0.1.0-beta.6 (2025-04-04)
+
+Full Changelog: [v0.1.0-beta.5...v0.1.0-beta.6](https://github.com/openai/openai-go/compare/v0.1.0-beta.5...v0.1.0-beta.6)
+
+### Features
+
+* **api:** manual updates ([4e39609](https://github.com/openai/openai-go/commit/4e39609d499b88039f1c90cc4b56e26f28fd58ea))
+* **client:** support unions in query and forms ([#347](https://github.com/openai/openai-go/issues/347)) ([cf8af37](https://github.com/openai/openai-go/commit/cf8af373ab7c019c75e886855009ffaca320d0e3))
+
+## 0.1.0-beta.5 (2025-04-03)
+
+Full Changelog: [v0.1.0-beta.4...v0.1.0-beta.5](https://github.com/openai/openai-go/compare/v0.1.0-beta.4...v0.1.0-beta.5)
+
+### Features
+
+* **api:** manual updates ([563cc50](https://github.com/openai/openai-go/commit/563cc505f2ab17749bb77e937342a6614243b975))
+* **client:** omitzero on required id parameter ([#339](https://github.com/openai/openai-go/issues/339)) ([c0b4842](https://github.com/openai/openai-go/commit/c0b484266ccd9faee66873916d8c0c92ea9f1014))
+
+
+### Bug Fixes
+
+* **client:** return error on bad custom url instead of panic ([#341](https://github.com/openai/openai-go/issues/341)) ([a06c5e6](https://github.com/openai/openai-go/commit/a06c5e632242e53d3fdcc8964931acb533a30b7e))
+* **client:** support multipart encoding array formats ([#342](https://github.com/openai/openai-go/issues/342)) ([5993b28](https://github.com/openai/openai-go/commit/5993b28309d02c2d748b54d98934ef401dcd193a))
+* **client:** unmarshal stream events into fresh memory ([#340](https://github.com/openai/openai-go/issues/340)) ([52c3e08](https://github.com/openai/openai-go/commit/52c3e08f51d471d728e5acd16b3c304b51be2d03))
+
+## 0.1.0-beta.4 (2025-04-02)
+
+Full Changelog: [v0.1.0-beta.3...v0.1.0-beta.4](https://github.com/openai/openai-go/compare/v0.1.0-beta.3...v0.1.0-beta.4)
+
+### Features
+
+* **api:** manual updates ([bc4fe73](https://github.com/openai/openai-go/commit/bc4fe73eec9c4d39229e4beae8eaafb55b1d3364))
+* **api:** manual updates ([aa7ff10](https://github.com/openai/openai-go/commit/aa7ff10b0616a6b2ece45cb10e9c83f25e35aded))
+
+
+### Chores
+
+* **docs:** update file uploads in README ([#333](https://github.com/openai/openai-go/issues/333)) ([471c452](https://github.com/openai/openai-go/commit/471c4525c94e83cf4b78cb6c9b2f65a8a27bf3ce))
+* **internal:** codegen related update ([#335](https://github.com/openai/openai-go/issues/335)) ([48422dc](https://github.com/openai/openai-go/commit/48422dcca333ab808ccb02506c033f1c69d2aa19))
+* Remove deprecated/unused remote spec feature ([c5077a1](https://github.com/openai/openai-go/commit/c5077a154a6db79b73cf4978bdc08212c6da6423))
+
+## 0.1.0-beta.3 (2025-03-28)
+
+Full Changelog: [v0.1.0-beta.2...v0.1.0-beta.3](https://github.com/openai/openai-go/compare/v0.1.0-beta.2...v0.1.0-beta.3)
+
+### โ BREAKING CHANGES
+
+* **client:** add enums ([#327](https://github.com/openai/openai-go/issues/327))
+
+### Features
+
+* **api:** add `get /chat/completions` endpoint ([e8ed116](https://github.com/openai/openai-go/commit/e8ed1168576c885cb26fbf819b9c8d24975749bd))
+* **api:** add `get /responses/{response_id}/input_items` endpoint ([8870c26](https://github.com/openai/openai-go/commit/8870c26f010a596adcf37ac10dba096bdd4394e3))
+
+
+### Bug Fixes
+
+* **client:** add enums ([#327](https://github.com/openai/openai-go/issues/327)) ([b0e3afb](https://github.com/openai/openai-go/commit/b0e3afbd6f18fd9fc2a5ea9174bd7ec0ac0614db))
+
+
+### Chores
+
+* add hash of OpenAPI spec/config inputs to .stats.yml ([104b786](https://github.com/openai/openai-go/commit/104b7861bb025514999b143f7d1de45d2dab659f))
+* add request options to client tests ([#321](https://github.com/openai/openai-go/issues/321)) ([f5239ce](https://github.com/openai/openai-go/commit/f5239ceecf36835341eac5121ed1770020c4806a))
+* **api:** updates to supported Voice IDs ([#325](https://github.com/openai/openai-go/issues/325)) ([477727a](https://github.com/openai/openai-go/commit/477727a44b0fb72493c4749cc60171e0d30f98ec))
+* **docs:** improve security documentation ([#319](https://github.com/openai/openai-go/issues/319)) ([0271053](https://github.com/openai/openai-go/commit/027105363ab30ac3e189234908169faf94e0ca49))
+* fix typos ([#324](https://github.com/openai/openai-go/issues/324)) ([dba15f7](https://github.com/openai/openai-go/commit/dba15f74d63814ce16f778e1017a209a42f46179))
+
## 0.1.0-beta.2 (2025-03-22)
Full Changelog: [v0.1.0-beta.1...v0.1.0-beta.2](https://github.com/openai/openai-go/compare/v0.1.0-beta.1...v0.1.0-beta.2)
@@ -4,7 +4,7 @@ To set up the repository, run:
```sh
$ ./scripts/bootstrap
-$ ./scripts/build
+$ ./scripts/lint
```
This will install all the required dependencies and build the SDK.
@@ -227,15 +227,15 @@ var name *string = animal.GetName()
The old SDK had a function `param.Null[T]()` which could set `param.Field[T]` to `null`.
-The new SDK uses `param.NullOpt[T]()` for to set a `param.Opt[T]` to `null`,
-and `param.NullObj[T]()` to set a param struct `T` to `null`.
+The new SDK uses `param.Null[T]()` for to set a `param.Opt[T]` to `null`,
+but `param.NullStruct[T]()` to set a param struct `T` to `null`.
```diff
-- var nullObj param.Field[BarParam] = param.Null[BarParam]()
-+ var nullObj BarParam = param.NullObj[BarParam]()
-
- var nullPrimitive param.Field[int64] = param.Null[int64]()
-+ var nullPrimitive param.Opt[int64] = param.NullOpt[int64]()
++ var nullPrimitive param.Opt[int64] = param.Null[int64]()
+
+- var nullStruct param.Field[BarParam] = param.Null[BarParam]()
++ var nullStruct BarParam = param.NullStruct[BarParam]()
```
## Sending custom values
@@ -248,7 +248,7 @@ foo := FooParams{
A: param.String("hello"),
- B: param.Raw[string](12) // sending `12` instead of a string
}
-+ foo.WithExtraFields(map[string]any{
++ foo.SetExtraFields(map[string]any{
+ "B": 12,
+ })
```
@@ -257,20 +257,20 @@ foo := FooParams{
## Checking for presence of optional fields
-The `.IsNull()` method has been changed to `.IsPresent()` to better reflect its behavior.
+The `.IsNull()` method has been changed to `.Valid()` to better reflect its behavior.
```diff
- if !resp.Foo.JSON.Bar.IsNull() {
-+ if resp.Foo.JSON.Bar.IsPresent() {
++ if resp.Foo.JSON.Bar.Valid() {
println("bar is present:", resp.Foo.Bar)
}
```
-| Previous | New | Returns true for values |
-| -------------- | ------------------- | ----------------------- |
-| `.IsNull()` | `!.IsPresent()` | `null` or Omitted |
-| `.IsMissing()` | `.Raw() == ""` | Omitted |
-| | `.IsExplicitNull()` | `null` |
+| Previous | New | Returns true for values |
+| -------------- | ------------------------ | ----------------------- |
+| `.IsNull()` | `!.Valid()` | `null` or Omitted |
+| `.IsMissing()` | `.Raw() == resp.Omitted` | Omitted |
+| | `.Raw() == resp.Null` |
## Checking Raw JSON of a response
@@ -2,8 +2,8 @@
<a href="https://pkg.go.dev/github.com/openai/openai-go"><img src="https://pkg.go.dev/badge/github.com/openai/openai-go.svg" alt="Go Reference"></a>
-The OpenAI Go library provides convenient access to [the OpenAI REST
-API](https://platform.openai.com/docs) from applications written in Go. The full API of this library can be found in [api.md](api.md).
+The OpenAI Go library provides convenient access to the [OpenAI REST API](https://platform.openai.com/docs)
+from applications written in Go.
> [!WARNING]
> The latest version of this package uses a new design with significant breaking changes.
@@ -26,7 +26,7 @@ Or to pin the version:
<!-- x-release-please-start-version -->
```sh
-go get -u 'github.com/openai/openai-go@v0.1.0-beta.2'
+go get -u 'github.com/openai/openai-go@v1.8.2'
```
<!-- x-release-please-end -->
@@ -295,52 +295,59 @@ func main() {
The openai library uses the [`omitzero`](https://tip.golang.org/doc/go1.24#encodingjsonpkgencodingjson)
semantics from the Go 1.24+ `encoding/json` release for request fields.
-Required primitive fields (`int64`, `string`, etc.) feature the tag <code>\`json:...,required\`</code>. These
+Required primitive fields (`int64`, `string`, etc.) feature the tag <code>\`json:"...,required"\`</code>. These
fields are always serialized, even their zero values.
-Optional primitive types are wrapped in a `param.Opt[T]`. Use the provided constructors set `param.Opt[T]` fields such as `openai.String(string)`, `openai.Int(int64)`, etc.
+Optional primitive types are wrapped in a `param.Opt[T]`. These fields can be set with the provided constructors, `openai.String(string)`, `openai.Int(int64)`, etc.
-Optional primitives, maps, slices and structs and string enums (represented as `string`) always feature the
-tag <code>\`json:"...,omitzero"\`</code>. Their zero values are considered omitted.
+Any `param.Opt[T]`, map, slice, struct or string enum uses the
+tag <code>\`json:"...,omitzero"\`</code>. Its zero value is considered omitted.
-Any non-nil slice of length zero will serialize as an empty JSON array, `"[]"`. Similarly, any non-nil map with length zero with serialize as an empty JSON object, `"{}"`.
-
-To send `null` instead of an `param.Opt[T]`, use `param.NullOpt[T]()`.
-To send `null` instead of a struct, use `param.NullObj[T]()`, where `T` is a struct.
-To send a custom value instead of a struct, use `param.OverrideObj[T](value)`.
-
-To override request structs contain a `.WithExtraFields(map[string]any)` method which can be used to
-send non-conforming fields in the request body. Extra fields take higher precedence than normal
-fields.
+The `param.IsOmitted(any)` function can confirm the presence of any `omitzero` field.
```go
-params := FooParams{
- ID: "id_xxx", // required property
- Name: openai.String("hello"), // optional property
- Description: param.NullOpt[string](), // explicit null property
+p := openai.ExampleParams{
+ ID: "id_xxx", // required property
+ Name: openai.String("..."), // optional property
Point: openai.Point{
- X: 0, // required field will serialize as 0
+ X: 0, // required field will serialize as 0
Y: openai.Int(1), // optional field will serialize as 1
- // ... omitted non-required fields will not be serialized
- }),
+ // ... omitted non-required fields will not be serialized
+ },
Origin: openai.Origin{}, // the zero value of [Origin] is considered omitted
}
+```
+
+To send `null` instead of a `param.Opt[T]`, use `param.Null[T]()`.
+To send `null` instead of a struct `T`, use `param.NullStruct[T]()`.
+
+```go
+p.Name = param.Null[string]() // 'null' instead of string
+p.Point = param.NullStruct[Point]() // 'null' instead of struct
+
+param.IsNull(p.Name) // true
+param.IsNull(p.Point) // true
+```
+Request structs contain a `.SetExtraFields(map[string]any)` method which can send non-conforming
+fields in the request body. Extra fields overwrite any struct fields with a matching
+key. For security reasons, only use `SetExtraFields` with trusted data.
+
+To send a custom value instead of a struct, use `param.Override[T](value)`.
+
+```go
// In cases where the API specifies a given type,
-// but you want to send something else, use [WithExtraFields]:
-params.WithExtraFields(map[string]any{
+// but you want to send something else, use [SetExtraFields]:
+p.SetExtraFields(map[string]any{
"x": 0.01, // send "x" as a float instead of int
})
// Send a number instead of an object
-custom := param.OverrideObj[openai.FooParams](12)
+custom := param.Override[openai.FooParams](12)
```
-When available, use the `.IsPresent()` method to check if an optional parameter is not omitted or `null`.
-Otherwise, the `param.IsOmitted(any)` function can confirm the presence of any `omitzero` field.
-
### Request unions
Unions are represented as a struct with fields prefixed by "Of" for each of it's variants,
@@ -352,8 +359,8 @@ These methods return a mutable pointer to the underlying data, if present.
```go
// Only one field can be non-zero, use param.IsOmitted() to check if a field is set
type AnimalUnionParam struct {
- OfCat *Cat `json:",omitzero,inline`
- OfDog *Dog `json:",omitzero,inline`
+ OfCat *Cat `json:",omitzero,inline`
+ OfDog *Dog `json:",omitzero,inline`
}
animal := AnimalUnionParam{
@@ -373,34 +380,54 @@ if address := animal.GetOwner().GetAddress(); address != nil {
### Response objects
-All fields in response structs are value types (not pointers or wrappers).
+All fields in response structs are ordinary value types (not pointers or wrappers).
+Response structs also include a special `JSON` field containing metadata about
+each property.
+
+```go
+type Animal struct {
+ Name string `json:"name,nullable"`
+ Owners int `json:"owners"`
+ Age int `json:"age"`
+ JSON struct {
+ Name respjson.Field
+ Owner respjson.Field
+ Age respjson.Field
+ ExtraFields map[string]respjson.Field
+ } `json:"-"`
+}
+```
-If a given field is `null`, not present, or invalid, the corresponding field
-will simply be its zero value.
+To handle optional data, use the `.Valid()` method on the JSON field.
+`.Valid()` returns true if a field is not `null`, not present, or couldn't be marshaled.
-All response structs also include a special `JSON` field, containing more detailed
-information about each property, which you can use like so:
+If `.Valid()` is false, the corresponding field will simply be its zero value.
```go
-if res.Name == "" {
- // true if `"name"` was unmarshalled successfully
- res.JSON.Name.IsPresent()
-
- res.JSON.Name.IsExplicitNull() // true if `"name"` is explicitly null
- res.JSON.Name.Raw() == "" // true if `"name"` field does not exist
-
- // When the API returns data that cannot be coerced to the expected type:
- if !res.JSON.Name.IsPresent() && res.JSON.Name.Raw() != "" {
- raw := res.JSON.Name.Raw()
-
- legacyName := struct{
- First string `json:"first"`
- Last string `json:"last"`
- }{}
- json.Unmarshal([]byte(raw), &legacyName)
- name = legacyName.First + " " + legacyName.Last
- }
-}
+raw := `{"owners": 1, "name": null}`
+
+var res Animal
+json.Unmarshal([]byte(raw), &res)
+
+// Accessing regular fields
+
+res.Owners // 1
+res.Name // ""
+res.Age // 0
+
+// Optional field checks
+
+res.JSON.Owners.Valid() // true
+res.JSON.Name.Valid() // false
+res.JSON.Age.Valid() // false
+
+// Raw JSON values
+
+res.JSON.Owners.Raw() // "1"
+res.JSON.Name.Raw() == "null" // true
+res.JSON.Name.Raw() == respjson.Null // true
+res.JSON.Age.Raw() == "" // true
+res.JSON.Age.Raw() == respjson.Omitted // true
```
These `.JSON` structs also include an `ExtraFields` map containing
@@ -423,31 +450,27 @@ the properties but prefixed with `Of` and feature the tag `json:"...,inline"`.
```go
type AnimalUnion struct {
- OfString string `json:",inline"`
- Name string `json:"name"`
- Owner Person `json:"owner"`
+ // From variants [Dog], [Cat]
+ Owner Person `json:"owner"`
+ // From variant [Dog]
+ DogBreed string `json:"dog_breed"`
+ // From variant [Cat]
+ CatBreed string `json:"cat_breed"`
// ...
+
JSON struct {
- OfString resp.Field
- Name resp.Field
- Owner resp.Field
+ Owner respjson.Field
// ...
- }
+ } `json:"-"`
}
// If animal variant
-if animal.Owner.Address.JSON.ZipCode == "" {
+if animal.Owner.Address.ZipCode == "" {
panic("missing zip code")
}
-// If string variant
-if !animal.OfString == "" {
- panic("expected a name")
-}
-
// Switch on the variant
-switch variant := animalOrName.AsAny().(type) {
-case string:
+switch variant := animal.AsAny().(type) {
case Dog:
case Cat:
default:
@@ -476,6 +499,8 @@ client.Chat.Completions.New(context.TODO(), ...,
)
```
+The request option `option.WithDebugLog(nil)` may be helpful while debugging.
+
See the [full list of request options](https://pkg.go.dev/github.com/openai/openai-go/option).
### Pagination
@@ -527,7 +552,7 @@ To handle errors, we recommend that you use the `errors.As` pattern:
```go
_, err := client.FineTuning.Jobs.New(context.TODO(), openai.FineTuningJobNewParams{
- Model: "babbage-002",
+ Model: openai.FineTuningJobNewParamsModelBabbage002,
TrainingFile: "file-abc123",
})
if err != nil {
@@ -564,7 +589,7 @@ client.Chat.Completions.New(
},
},
}},
- Model: shared.ChatModelO3Mini,
+ Model: shared.ChatModelGPT4_1,
},
// This sets the per-retry timeout
option.WithRequestTimeout(20*time.Second),
@@ -581,30 +606,145 @@ The file name and content-type can be customized by implementing `Name() string`
string` on the run-time type of `io.Reader`. Note that `os.File` implements `Name() string`, so a
file returned by `os.Open` will be sent with the file name on disk.
-We also provide a helper `openai.FileParam(reader io.Reader, filename string, contentType string)`
+We also provide a helper `openai.File(reader io.Reader, filename string, contentType string)`
which can be used to wrap any `io.Reader` with the appropriate file name and content type.
```go
// A file from the file system
file, err := os.Open("input.jsonl")
openai.FileNewParams{
- File: openai.F[io.Reader](file),
+ File: file,
Purpose: openai.FilePurposeFineTune,
}
// A file from a string
openai.FileNewParams{
- File: openai.F[io.Reader](strings.NewReader("my file contents")),
+ File: strings.NewReader("my file contents"),
Purpose: openai.FilePurposeFineTune,
}
// With a custom filename and contentType
openai.FileNewParams{
- File: openai.FileParam(strings.NewReader(`{"hello": "foo"}`), "file.go", "application/json"),
+ File: openai.File(strings.NewReader(`{"hello": "foo"}`), "file.go", "application/json"),
Purpose: openai.FilePurposeFineTune,
}
```
+## Webhook Verification
+
+Verifying webhook signatures is _optional but encouraged_.
+
+For more information about webhooks, see [the API docs](https://platform.openai.com/docs/guides/webhooks).
+
+### Parsing webhook payloads
+
+For most use cases, you will likely want to verify the webhook and parse the payload at the same time. To achieve this, we provide the method `client.Webhooks.Unwrap()`, which parses a webhook request and verifies that it was sent by OpenAI. This method will return an error if the signature is invalid.
+
+Note that the `body` parameter should be the raw JSON bytes sent from the server (do not parse it first). The `Unwrap()` method will parse this JSON for you into an event object after verifying the webhook was sent from OpenAI.
+
+```go
+package main
+
+import (
+ "io"
+ "log"
+ "net/http"
+ "os"
+
+ "github.com/gin-gonic/gin"
+ "github.com/openai/openai-go"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/webhooks"
+)
+
+func main() {
+ client := openai.NewClient(
+ option.WithWebhookSecret(os.Getenv("OPENAI_WEBHOOK_SECRET")), // env var used by default; explicit here.
+ )
+
+ r := gin.Default()
+
+ r.POST("/webhook", func(c *gin.Context) {
+ body, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reading request body"})
+ return
+ }
+ defer c.Request.Body.Close()
+
+ webhookEvent, err := client.Webhooks.Unwrap(body, c.Request.Header)
+ if err != nil {
+ log.Printf("Invalid webhook signature: %v", err)
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid signature"})
+ return
+ }
+
+ switch event := webhookEvent.AsAny().(type) {
+ case webhooks.ResponseCompletedWebhookEvent:
+ log.Printf("Response completed: %+v", event.Data)
+ case webhooks.ResponseFailedWebhookEvent:
+ log.Printf("Response failed: %+v", event.Data)
+ default:
+ log.Printf("Unhandled event type: %T", event)
+ }
+
+ c.JSON(http.StatusOK, gin.H{"message": "ok"})
+ })
+
+ r.Run(":8000")
+}
+```
+
+### Verifying webhook payloads directly
+
+In some cases, you may want to verify the webhook separately from parsing the payload. If you prefer to handle these steps separately, we provide the method `client.Webhooks.VerifySignature()` to _only verify_ the signature of a webhook request. Like `Unwrap()`, this method will return an error if the signature is invalid.
+
+Note that the `body` parameter should be the raw JSON bytes sent from the server (do not parse it first). You will then need to parse the body after verifying the signature.
+
+```go
+package main
+
+import (
+ "encoding/json"
+ "io"
+ "log"
+ "net/http"
+ "os"
+
+ "github.com/gin-gonic/gin"
+ "github.com/openai/openai-go"
+ "github.com/openai/openai-go/option"
+)
+
+func main() {
+ client := openai.NewClient(
+ option.WithWebhookSecret(os.Getenv("OPENAI_WEBHOOK_SECRET")), // env var used by default; explicit here.
+ )
+
+ r := gin.Default()
+
+ r.POST("/webhook", func(c *gin.Context) {
+ body, err := io.ReadAll(c.Request.Body)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "Error reading request body"})
+ return
+ }
+ defer c.Request.Body.Close()
+
+ err = client.Webhooks.VerifySignature(body, c.Request.Header)
+ if err != nil {
+ log.Printf("Invalid webhook signature: %v", err)
+ c.JSON(http.StatusBadRequest, gin.H{"error": "invalid signature"})
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"message": "ok"})
+ })
+
+ r.Run(":8000")
+}
+```
+
### Retries
Certain errors will be automatically retried 2 times by default, with a short exponential backoff.
@@ -630,7 +770,7 @@ client.Chat.Completions.New(
},
},
}},
- Model: shared.ChatModelO3Mini,
+ Model: shared.ChatModelGPT4_1,
},
option.WithMaxRetries(5),
)
@@ -654,7 +794,7 @@ chatCompletion, err := client.Chat.Completions.New(
},
},
}},
- Model: shared.ChatModelO3Mini,
+ Model: shared.ChatModelGPT4_1,
},
option.WithResponseInto(&response),
)
@@ -681,7 +821,7 @@ To make requests to undocumented endpoints, you can use `client.Get`, `client.Po
var (
// params can be an io.Reader, a []byte, an encoding/json serializable object,
// or a "โฆParams" struct defined in this library.
- params map[string]interface{}
+ params map[string]any
// result can be an []byte, *http.Response, a encoding/json deserializable object,
// or a model defined in this library.
@@ -16,13 +16,13 @@ before making any information public.
## Reporting Non-SDK Related Security Issues
If you encounter security issues that are not directly related to SDKs but pertain to the services
-or products provided by OpenAI please follow the respective company's security reporting guidelines.
+or products provided by OpenAI, please follow the respective company's security reporting guidelines.
### OpenAI Terms and Policies
Our Security Policy can be found at [Security Policy URL](https://openai.com/policies/coordinated-vulnerability-disclosure-policy).
-Please contact disclosure@openai.com for any questions or concerns regarding security of our services.
+Please contact disclosure@openai.com for any questions or concerns regarding the security of our services.
---
@@ -5,7 +5,6 @@ package openai
import (
"github.com/openai/openai-go/internal/apierror"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
"github.com/openai/openai-go/shared"
)
@@ -20,6 +19,36 @@ type Error = apierror.Error
// This is an alias to an internal type.
type ChatModel = shared.ChatModel
+// Equals "gpt-4.1"
+const ChatModelGPT4_1 = shared.ChatModelGPT4_1
+
+// Equals "gpt-4.1-mini"
+const ChatModelGPT4_1Mini = shared.ChatModelGPT4_1Mini
+
+// Equals "gpt-4.1-nano"
+const ChatModelGPT4_1Nano = shared.ChatModelGPT4_1Nano
+
+// Equals "gpt-4.1-2025-04-14"
+const ChatModelGPT4_1_2025_04_14 = shared.ChatModelGPT4_1_2025_04_14
+
+// Equals "gpt-4.1-mini-2025-04-14"
+const ChatModelGPT4_1Mini2025_04_14 = shared.ChatModelGPT4_1Mini2025_04_14
+
+// Equals "gpt-4.1-nano-2025-04-14"
+const ChatModelGPT4_1Nano2025_04_14 = shared.ChatModelGPT4_1Nano2025_04_14
+
+// Equals "o4-mini"
+const ChatModelO4Mini = shared.ChatModelO4Mini
+
+// Equals "o4-mini-2025-04-16"
+const ChatModelO4Mini2025_04_16 = shared.ChatModelO4Mini2025_04_16
+
+// Equals "o3"
+const ChatModelO3 = shared.ChatModelO3
+
+// Equals "o3-2025-04-16"
+const ChatModelO3_2025_04_16 = shared.ChatModelO3_2025_04_16
+
// Equals "o3-mini"
const ChatModelO3Mini = shared.ChatModelO3Mini
@@ -65,6 +94,9 @@ const ChatModelGPT4oAudioPreview2024_10_01 = shared.ChatModelGPT4oAudioPreview20
// Equals "gpt-4o-audio-preview-2024-12-17"
const ChatModelGPT4oAudioPreview2024_12_17 = shared.ChatModelGPT4oAudioPreview2024_12_17
+// Equals "gpt-4o-audio-preview-2025-06-03"
+const ChatModelGPT4oAudioPreview2025_06_03 = shared.ChatModelGPT4oAudioPreview2025_06_03
+
// Equals "gpt-4o-mini-audio-preview"
const ChatModelGPT4oMiniAudioPreview = shared.ChatModelGPT4oMiniAudioPreview
@@ -86,6 +118,9 @@ const ChatModelGPT4oMiniSearchPreview2025_03_11 = shared.ChatModelGPT4oMiniSearc
// Equals "chatgpt-4o-latest"
const ChatModelChatgpt4oLatest = shared.ChatModelChatgpt4oLatest
+// Equals "codex-mini-latest"
+const ChatModelCodexMiniLatest = shared.ChatModelCodexMiniLatest
+
// Equals "gpt-4o-mini"
const ChatModelGPT4oMini = shared.ChatModelGPT4oMini
@@ -254,16 +289,6 @@ type FunctionParameters = shared.FunctionParameters
// This is an alias to an internal type.
type Metadata = shared.Metadata
-// Set of 16 key-value pairs that can be attached to an object. This can be useful
-// for storing additional information about the object in a structured format, and
-// querying for objects via API or the dashboard.
-//
-// Keys are strings with a maximum length of 64 characters. Values are strings with
-// a maximum length of 512 characters.
-//
-// This is an alias to an internal type.
-type MetadataParam = shared.MetadataParam
-
// **o-series models only**
//
// Configuration options for
@@ -272,21 +297,40 @@ type MetadataParam = shared.MetadataParam
// This is an alias to an internal type.
type Reasoning = shared.Reasoning
-// **computer_use_preview only**
+// **Deprecated:** use `summary` instead.
//
// A summary of the reasoning performed by the model. This can be useful for
-// debugging and understanding the model's reasoning process. One of `concise` or
-// `detailed`.
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
//
// This is an alias to an internal type.
type ReasoningGenerateSummary = shared.ReasoningGenerateSummary
+// Equals "auto"
+const ReasoningGenerateSummaryAuto = shared.ReasoningGenerateSummaryAuto
+
// Equals "concise"
const ReasoningGenerateSummaryConcise = shared.ReasoningGenerateSummaryConcise
// Equals "detailed"
const ReasoningGenerateSummaryDetailed = shared.ReasoningGenerateSummaryDetailed
+// A summary of the reasoning performed by the model. This can be useful for
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
+//
+// This is an alias to an internal type.
+type ReasoningSummary = shared.ReasoningSummary
+
+// Equals "auto"
+const ReasoningSummaryAuto = shared.ReasoningSummaryAuto
+
+// Equals "concise"
+const ReasoningSummaryConcise = shared.ReasoningSummaryConcise
+
+// Equals "detailed"
+const ReasoningSummaryDetailed = shared.ReasoningSummaryDetailed
+
// **o-series models only**
//
// Configuration options for
@@ -371,18 +415,26 @@ const ResponsesModelO1Pro = shared.ResponsesModelO1Pro
// Equals "o1-pro-2025-03-19"
const ResponsesModelO1Pro2025_03_19 = shared.ResponsesModelO1Pro2025_03_19
+// Equals "o3-pro"
+const ResponsesModelO3Pro = shared.ResponsesModelO3Pro
+
+// Equals "o3-pro-2025-06-10"
+const ResponsesModelO3Pro2025_06_10 = shared.ResponsesModelO3Pro2025_06_10
+
+// Equals "o3-deep-research"
+const ResponsesModelO3DeepResearch = shared.ResponsesModelO3DeepResearch
+
+// Equals "o3-deep-research-2025-06-26"
+const ResponsesModelO3DeepResearch2025_06_26 = shared.ResponsesModelO3DeepResearch2025_06_26
+
+// Equals "o4-mini-deep-research"
+const ResponsesModelO4MiniDeepResearch = shared.ResponsesModelO4MiniDeepResearch
+
+// Equals "o4-mini-deep-research-2025-06-26"
+const ResponsesModelO4MiniDeepResearch2025_06_26 = shared.ResponsesModelO4MiniDeepResearch2025_06_26
+
// Equals "computer-use-preview"
const ResponsesModelComputerUsePreview = shared.ResponsesModelComputerUsePreview
// Equals "computer-use-preview-2025-03-11"
const ResponsesModelComputerUsePreview2025_03_11 = shared.ResponsesModelComputerUsePreview2025_03_11
-
-func toParam[T comparable](value T, meta resp.Field) param.Opt[T] {
- if meta.IsPresent() {
- return param.NewOpt(value)
- }
- if meta.IsExplicitNull() {
- return param.NullOpt[T]()
- }
- return param.Opt[T]{}
-}
@@ -5,7 +5,7 @@
- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#CompoundFilterParam">CompoundFilterParam</a>
- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#FunctionDefinitionParam">FunctionDefinitionParam</a>
- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#FunctionParameters">FunctionParameters</a>
-- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#MetadataParam">MetadataParam</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#Metadata">Metadata</a>
- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#ReasoningParam">ReasoningParam</a>
- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#ReasoningEffort">ReasoningEffort</a>
- <a href="https://pkg.go.dev/github.com/openai/openai-go/shared">shared</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/shared#ResponseFormatJSONObjectParam">ResponseFormatJSONObjectParam</a>
@@ -221,6 +221,26 @@ Methods:
# FineTuning
+## Methods
+
+Params Types:
+
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#DpoHyperparameters">DpoHyperparameters</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#DpoMethodParam">DpoMethodParam</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#ReinforcementHyperparameters">ReinforcementHyperparameters</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#ReinforcementMethodParam">ReinforcementMethodParam</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#SupervisedHyperparameters">SupervisedHyperparameters</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#SupervisedMethodParam">SupervisedMethodParam</a>
+
+Response Types:
+
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#DpoHyperparametersResp">DpoHyperparametersResp</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#DpoMethod">DpoMethod</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#ReinforcementHyperparametersResp">ReinforcementHyperparametersResp</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#ReinforcementMethod">ReinforcementMethod</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#SupervisedHyperparametersResp">SupervisedHyperparametersResp</a>
+- <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#SupervisedMethod">SupervisedMethod</a>
+
## Jobs
Response Types:
@@ -237,6 +257,8 @@ Methods:
- <code title="get /fine_tuning/jobs">client.FineTuning.Jobs.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobService.List">List</a>(ctx <a href="https://pkg.go.dev/context">context</a>.<a href="https://pkg.go.dev/context#Context">Context</a>, query <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobListParams">FineTuningJobListParams</a>) (<a href="https://pkg.go.dev/github.com/openai/openai-go/packages/pagination">pagination</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/packages/pagination#CursorPage">CursorPage</a>[<a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJob">FineTuningJob</a>], <a href="https://pkg.go.dev/builtin#error">error</a>)</code>
- <code title="post /fine_tuning/jobs/{fine_tuning_job_id}/cancel">client.FineTuning.Jobs.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobService.Cancel">Cancel</a>(ctx <a href="https://pkg.go.dev/context">context</a>.<a href="https://pkg.go.dev/context#Context">Context</a>, fineTuningJobID <a href="https://pkg.go.dev/builtin#string">string</a>) (<a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJob">FineTuningJob</a>, <a href="https://pkg.go.dev/builtin#error">error</a>)</code>
- <code title="get /fine_tuning/jobs/{fine_tuning_job_id}/events">client.FineTuning.Jobs.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobService.ListEvents">ListEvents</a>(ctx <a href="https://pkg.go.dev/context">context</a>.<a href="https://pkg.go.dev/context#Context">Context</a>, fineTuningJobID <a href="https://pkg.go.dev/builtin#string">string</a>, query <a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobListEventsParams">FineTuningJobListEventsParams</a>) (<a href="https://pkg.go.dev/github.com/openai/openai-go/packages/pagination">pagination</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go/packages/pagination#CursorPage">CursorPage</a>[<a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobEvent">FineTuningJobEvent</a>], <a href="https://pkg.go.dev/builtin#error">error</a>)</code>
+- <code title="post /fine_tuning/jobs/{fine_tuning_job_id}/pause">client.FineTuning.Jobs.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobService.Pause">Pause</a>(ctx <a href="https://pkg.go.dev/context">context</a>.<a href="https://pkg.go.dev/context#Context">Context</a>, fineTuningJobID <a href="https://pkg.go.dev/builtin#string">string</a>) (<a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJob">FineTuningJob</a>, <a href="https://pkg.go.dev/builtin#error">error</a>)</code>
+- <code title="post /fine_tuning/jobs/{fine_tuning_job_id}/resume">client.FineTuning.Jobs.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJobService.Resume">Resume</a>(ctx <a href="https://pkg.go.dev/context">context</a>.<a href="https://pkg.go.dev/context#Context">Context</a>, fineTuningJobID <a href="https://pkg.go.dev/builtin#string">string</a>) (<a href="https://pkg.go.dev/github.com/openai/openai-go">openai</a>.<a href="https://pkg.go.dev/github.com/openai/openai-go#FineTuningJob">FineTuningJob</a>, <a href="https://pkg.go.dev/builtin#error">error</a>)</code>
### Checkpoints
@@ -248,6 +270,58 @@ Methods:
@@ -6,6 +6,7 @@ import (
"context"
"net/http"
+ "github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
@@ -54,12 +55,9 @@ type AudioSpeechNewParams struct {
// `tts-1`, `tts-1-hd` or `gpt-4o-mini-tts`.
Model SpeechModel `json:"model,omitzero,required"`
// The voice to use when generating the audio. Supported voices are `alloy`, `ash`,
- // `coral`, `echo`, `fable`, `onyx`, `nova`, `sage` and `shimmer`. Previews of the
- // voices are available in the
+ // `ballad`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`, `shimmer`, and
+ // `verse`. Previews of the voices are available in the
// [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
- //
- // Any of "alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage",
- // "shimmer".
Voice AudioSpeechNewParamsVoice `json:"voice,omitzero,required"`
// Control the voice of your generated audio with additional instructions. Does not
// work with `tts-1` or `tts-1-hd`.
@@ -72,27 +70,32 @@ type AudioSpeechNewParams struct {
//
// Any of "mp3", "opus", "aac", "flac", "wav", "pcm".
ResponseFormat AudioSpeechNewParamsResponseFormat `json:"response_format,omitzero"`
+ // The format to stream the audio in. Supported formats are `sse` and `audio`.
+ // `sse` is not supported for `tts-1` or `tts-1-hd`.
+ //
+ // Any of "sse", "audio".
+ StreamFormat AudioSpeechNewParamsStreamFormat `json:"stream_format,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f AudioSpeechNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r AudioSpeechNewParams) MarshalJSON() (data []byte, err error) {
type shadow AudioSpeechNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *AudioSpeechNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The voice to use when generating the audio. Supported voices are `alloy`, `ash`,
-// `coral`, `echo`, `fable`, `onyx`, `nova`, `sage` and `shimmer`. Previews of the
-// voices are available in the
+// `ballad`, `coral`, `echo`, `fable`, `onyx`, `nova`, `sage`, `shimmer`, and
+// `verse`. Previews of the voices are available in the
// [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech#voice-options).
type AudioSpeechNewParamsVoice string
const (
AudioSpeechNewParamsVoiceAlloy AudioSpeechNewParamsVoice = "alloy"
AudioSpeechNewParamsVoiceAsh AudioSpeechNewParamsVoice = "ash"
+ AudioSpeechNewParamsVoiceBallad AudioSpeechNewParamsVoice = "ballad"
AudioSpeechNewParamsVoiceCoral AudioSpeechNewParamsVoice = "coral"
AudioSpeechNewParamsVoiceEcho AudioSpeechNewParamsVoice = "echo"
AudioSpeechNewParamsVoiceFable AudioSpeechNewParamsVoice = "fable"
@@ -100,6 +103,7 @@ const (
AudioSpeechNewParamsVoiceNova AudioSpeechNewParamsVoice = "nova"
AudioSpeechNewParamsVoiceSage AudioSpeechNewParamsVoice = "sage"
AudioSpeechNewParamsVoiceShimmer AudioSpeechNewParamsVoice = "shimmer"
+ AudioSpeechNewParamsVoiceVerse AudioSpeechNewParamsVoice = "verse"
)
// The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`,
@@ -114,3 +118,12 @@ const (
AudioSpeechNewParamsResponseFormatWAV AudioSpeechNewParamsResponseFormat = "wav"
AudioSpeechNewParamsResponseFormatPCM AudioSpeechNewParamsResponseFormat = "pcm"
)
+
+// The format to stream the audio in. Supported formats are `sse` and `audio`.
+// `sse` is not supported for `tts-1` or `tts-1-hd`.
+type AudioSpeechNewParamsStreamFormat string
+
+const (
+ AudioSpeechNewParamsStreamFormatSSE AudioSpeechNewParamsStreamFormat = "sse"
+ AudioSpeechNewParamsStreamFormatAudio AudioSpeechNewParamsStreamFormat = "audio"
+)
@@ -15,7 +15,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/packages/ssestream"
"github.com/openai/openai-go/shared/constant"
)
@@ -54,7 +54,9 @@ func (r *AudioTranscriptionService) NewStreaming(ctx context.Context, body Audio
err error
)
opts = append(r.Options[:], opts...)
- opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...)
+ body.SetExtraFields(map[string]any{
+ "stream": "true",
+ })
path := "audio/transcriptions"
err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &raw, opts...)
return ssestream.NewStream[TranscriptionStreamEventUnion](ssestream.NewDecoder(raw), err)
@@ -69,12 +71,14 @@ type Transcription struct {
// models `gpt-4o-transcribe` and `gpt-4o-mini-transcribe` if `logprobs` is added
// to the `include` array.
Logprobs []TranscriptionLogprob `json:"logprobs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // Token usage statistics for the request.
+ Usage TranscriptionUsageUnion `json:"usage"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- Logprobs resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ Logprobs respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -92,13 +96,12 @@ type TranscriptionLogprob struct {
Bytes []float64 `json:"bytes"`
// The log probability of the token.
Logprob float64 `json:"logprob"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Token resp.Field
- Bytes resp.Field
- Logprob resp.Field
- ExtraFields map[string]resp.Field
+ Token respjson.Field
+ Bytes respjson.Field
+ Logprob respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -109,6 +112,153 @@ func (r *TranscriptionLogprob) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+// TranscriptionUsageUnion contains all possible properties and values from
+// [TranscriptionUsageTokens], [TranscriptionUsageDuration].
+//
+// Use the [TranscriptionUsageUnion.AsAny] method to switch on the variant.
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+type TranscriptionUsageUnion struct {
+ // This field is from variant [TranscriptionUsageTokens].
+ InputTokens int64 `json:"input_tokens"`
+ // This field is from variant [TranscriptionUsageTokens].
+ OutputTokens int64 `json:"output_tokens"`
+ // This field is from variant [TranscriptionUsageTokens].
+ TotalTokens int64 `json:"total_tokens"`
+ // Any of "tokens", "duration".
+ Type string `json:"type"`
+ // This field is from variant [TranscriptionUsageTokens].
+ InputTokenDetails TranscriptionUsageTokensInputTokenDetails `json:"input_token_details"`
+ // This field is from variant [TranscriptionUsageDuration].
+ Duration float64 `json:"duration"`
+ JSON struct {
+ InputTokens respjson.Field
+ OutputTokens respjson.Field
+ TotalTokens respjson.Field
+ Type respjson.Field
+ InputTokenDetails respjson.Field
+ Duration respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// anyTranscriptionUsage is implemented by each variant of
+// [TranscriptionUsageUnion] to add type safety for the return type of
+// [TranscriptionUsageUnion.AsAny]
+type anyTranscriptionUsage interface {
+ implTranscriptionUsageUnion()
+}
+
+func (TranscriptionUsageTokens) implTranscriptionUsageUnion() {}
+func (TranscriptionUsageDuration) implTranscriptionUsageUnion() {}
+
+// Use the following switch statement to find the correct variant
+//
+// switch variant := TranscriptionUsageUnion.AsAny().(type) {
+// case openai.TranscriptionUsageTokens:
+// case openai.TranscriptionUsageDuration:
+// default:
+// fmt.Errorf("no variant present")
+// }
+func (u TranscriptionUsageUnion) AsAny() anyTranscriptionUsage {
+ switch u.Type {
+ case "tokens":
+ return u.AsTokens()
+ case "duration":
+ return u.AsDuration()
+ }
+ return nil
+}
+
+func (u TranscriptionUsageUnion) AsTokens() (v TranscriptionUsageTokens) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u TranscriptionUsageUnion) AsDuration() (v TranscriptionUsageDuration) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u TranscriptionUsageUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *TranscriptionUsageUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Usage statistics for models billed by token usage.
+type TranscriptionUsageTokens struct {
+ // Number of input tokens billed for this request.
+ InputTokens int64 `json:"input_tokens,required"`
+ // Number of output tokens generated.
+ OutputTokens int64 `json:"output_tokens,required"`
+ // Total number of tokens used (input + output).
+ TotalTokens int64 `json:"total_tokens,required"`
+ // The type of the usage object. Always `tokens` for this variant.
+ Type constant.Tokens `json:"type,required"`
+ // Details about the input tokens billed for this request.
+ InputTokenDetails TranscriptionUsageTokensInputTokenDetails `json:"input_token_details"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ InputTokens respjson.Field
+ OutputTokens respjson.Field
+ TotalTokens respjson.Field
+ Type respjson.Field
+ InputTokenDetails respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r TranscriptionUsageTokens) RawJSON() string { return r.JSON.raw }
+func (r *TranscriptionUsageTokens) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Details about the input tokens billed for this request.
+type TranscriptionUsageTokensInputTokenDetails struct {
+ // Number of audio tokens billed for this request.
+ AudioTokens int64 `json:"audio_tokens"`
+ // Number of text tokens billed for this request.
+ TextTokens int64 `json:"text_tokens"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ AudioTokens respjson.Field
+ TextTokens respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r TranscriptionUsageTokensInputTokenDetails) RawJSON() string { return r.JSON.raw }
+func (r *TranscriptionUsageTokensInputTokenDetails) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Usage statistics for models billed by audio input duration.
+type TranscriptionUsageDuration struct {
+ // Duration of the input audio in seconds.
+ Duration float64 `json:"duration,required"`
+ // The type of the usage object. Always `duration` for this variant.
+ Type constant.Duration `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Duration respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r TranscriptionUsageDuration) RawJSON() string { return r.JSON.raw }
+func (r *TranscriptionUsageDuration) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
type TranscriptionInclude string
const (
@@ -131,24 +281,37 @@ type TranscriptionStreamEventUnion struct {
Logprobs TranscriptionStreamEventUnionLogprobs `json:"logprobs"`
// This field is from variant [TranscriptionTextDoneEvent].
Text string `json:"text"`
- JSON struct {
- Delta resp.Field
- Type resp.Field
- Logprobs resp.Field
- Text resp.Field
+ // This field is from variant [TranscriptionTextDoneEvent].
+ Usage TranscriptionTextDoneEventUsage `json:"usage"`
+ JSON struct {
+ Delta respjson.Field
+ Type respjson.Field
+ Logprobs respjson.Field
+ Text respjson.Field
+ Usage respjson.Field
raw string
} `json:"-"`
}
+// anyTranscriptionStreamEvent is implemented by each variant of
+// [TranscriptionStreamEventUnion] to add type safety for the return type of
+// [TranscriptionStreamEventUnion.AsAny]
+type anyTranscriptionStreamEvent interface {
+ implTranscriptionStreamEventUnion()
+}
+
+func (TranscriptionTextDeltaEvent) implTranscriptionStreamEventUnion() {}
+func (TranscriptionTextDoneEvent) implTranscriptionStreamEventUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := TranscriptionStreamEventUnion.AsAny().(type) {
-// case TranscriptionTextDeltaEvent:
-// case TranscriptionTextDoneEvent:
+// case openai.TranscriptionTextDeltaEvent:
+// case openai.TranscriptionTextDoneEvent:
// default:
// fmt.Errorf("no variant present")
// }
-func (u TranscriptionStreamEventUnion) AsAny() any {
+func (u TranscriptionStreamEventUnion) AsAny() anyTranscriptionStreamEvent {
switch u.Type {
case "transcript.text.delta":
return u.AsTranscriptTextDelta()
@@ -193,8 +356,8 @@ type TranscriptionStreamEventUnionLogprobs struct {
// [[]TranscriptionTextDoneEventLogprob] instead of an object.
OfTranscriptionTextDoneEventLogprobs []TranscriptionTextDoneEventLogprob `json:",inline"`
JSON struct {
- OfTranscriptionTextDeltaEventLogprobs resp.Field
- OfTranscriptionTextDoneEventLogprobs resp.Field
+ OfTranscriptionTextDeltaEventLogprobs respjson.Field
+ OfTranscriptionTextDoneEventLogprobs respjson.Field
raw string
} `json:"-"`
}
@@ -216,13 +379,12 @@ type TranscriptionTextDeltaEvent struct {
// [create a transcription](https://platform.openai.com/docs/api-reference/audio/create-transcription)
// with the `include[]` parameter set to `logprobs`.
Logprobs []TranscriptionTextDeltaEventLogprob `json:"logprobs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Delta resp.Field
- Type resp.Field
- Logprobs resp.Field
- ExtraFields map[string]resp.Field
+ Delta respjson.Field
+ Type respjson.Field
+ Logprobs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -237,16 +399,15 @@ type TranscriptionTextDeltaEventLogprob struct {
// The token that was used to generate the log probability.
Token string `json:"token"`
// The bytes that were used to generate the log probability.
- Bytes []interface{} `json:"bytes"`
+ Bytes []int64 `json:"bytes"`
// The log probability of the token.
Logprob float64 `json:"logprob"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Token resp.Field
- Bytes resp.Field
- Logprob resp.Field
- ExtraFields map[string]resp.Field
+ Token respjson.Field
+ Bytes respjson.Field
+ Logprob respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -271,13 +432,15 @@ type TranscriptionTextDoneEvent struct {
// [create a transcription](https://platform.openai.com/docs/api-reference/audio/create-transcription)
// with the `include[]` parameter set to `logprobs`.
Logprobs []TranscriptionTextDoneEventLogprob `json:"logprobs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // Usage statistics for models billed by token usage.
+ Usage TranscriptionTextDoneEventUsage `json:"usage"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- Type resp.Field
- Logprobs resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ Type respjson.Field
+ Logprobs respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -292,16 +455,15 @@ type TranscriptionTextDoneEventLogprob struct {
// The token that was used to generate the log probability.
Token string `json:"token"`
// The bytes that were used to generate the log probability.
- Bytes []interface{} `json:"bytes"`
+ Bytes []int64 `json:"bytes"`
// The log probability of the token.
Logprob float64 `json:"logprob"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Token resp.Field
- Bytes resp.Field
- Logprob resp.Field
- ExtraFields map[string]resp.Field
+ Token respjson.Field
+ Bytes respjson.Field
+ Logprob respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -312,10 +474,61 @@ func (r *TranscriptionTextDoneEventLogprob) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+// Usage statistics for models billed by token usage.
+type TranscriptionTextDoneEventUsage struct {
+ // Number of input tokens billed for this request.
+ InputTokens int64 `json:"input_tokens,required"`
+ // Number of output tokens generated.
+ OutputTokens int64 `json:"output_tokens,required"`
+ // Total number of tokens used (input + output).
+ TotalTokens int64 `json:"total_tokens,required"`
+ // The type of the usage object. Always `tokens` for this variant.
+ Type constant.Tokens `json:"type,required"`
+ // Details about the input tokens billed for this request.
+ InputTokenDetails TranscriptionTextDoneEventUsageInputTokenDetails `json:"input_token_details"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ InputTokens respjson.Field
+ OutputTokens respjson.Field
+ TotalTokens respjson.Field
+ Type respjson.Field
+ InputTokenDetails respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r TranscriptionTextDoneEventUsage) RawJSON() string { return r.JSON.raw }
+func (r *TranscriptionTextDoneEventUsage) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Details about the input tokens billed for this request.
+type TranscriptionTextDoneEventUsageInputTokenDetails struct {
+ // Number of audio tokens billed for this request.
+ AudioTokens int64 `json:"audio_tokens"`
+ // Number of text tokens billed for this request.
+ TextTokens int64 `json:"text_tokens"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ AudioTokens respjson.Field
+ TextTokens respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r TranscriptionTextDoneEventUsageInputTokenDetails) RawJSON() string { return r.JSON.raw }
+func (r *TranscriptionTextDoneEventUsageInputTokenDetails) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
type AudioTranscriptionNewParams struct {
// The audio file object (not file name) to transcribe, in one of these formats:
// flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
- File io.Reader `json:"file,required" format:"binary"`
+ File io.Reader `json:"file,omitzero,required" format:"binary"`
// ID of the model to use. The options are `gpt-4o-transcribe`,
// `gpt-4o-mini-transcribe`, and `whisper-1` (which is powered by our open source
// Whisper V2 model).
@@ -335,6 +548,11 @@ type AudioTranscriptionNewParams struct {
// [log probability](https://en.wikipedia.org/wiki/Log_probability) to
// automatically increase the temperature until certain thresholds are hit.
Temperature param.Opt[float64] `json:"temperature,omitzero"`
+ // Controls how the audio is cut into chunks. When set to `"auto"`, the server
+ // first normalizes loudness and then uses voice activity detection (VAD) to choose
+ // boundaries. `server_vad` object can be provided to tweak VAD detection
+ // parameters manually. If unset, the audio is transcribed as a single block.
+ ChunkingStrategy AudioTranscriptionNewParamsChunkingStrategyUnion `json:"chunking_strategy,omitzero"`
// Additional information to include in the transcription response. `logprobs` will
// return the log probabilities of the tokens in the response to understand the
// model's confidence in the transcription. `logprobs` only works with
@@ -352,18 +570,19 @@ type AudioTranscriptionNewParams struct {
// Either or both of these options are supported: `word`, or `segment`. Note: There
// is no additional latency for segment timestamps, but generating word timestamps
// incurs additional latency.
+ //
+ // Any of "word", "segment".
TimestampGranularities []string `json:"timestamp_granularities,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f AudioTranscriptionNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r AudioTranscriptionNewParams) MarshalMultipart() (data []byte, contentType string, err error) {
buf := bytes.NewBuffer(nil)
writer := multipart.NewWriter(buf)
err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
if err != nil {
writer.Close()
return nil, "", err
@@ -374,3 +593,62 @@ func (r AudioTranscriptionNewParams) MarshalMultipart() (data []byte, contentTyp
}
return buf.Bytes(), writer.FormDataContentType(), nil
}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type AudioTranscriptionNewParamsChunkingStrategyUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfAudioTranscriptionNewsChunkingStrategyVadConfig *AudioTranscriptionNewParamsChunkingStrategyVadConfig `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u AudioTranscriptionNewParamsChunkingStrategyUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfAudioTranscriptionNewsChunkingStrategyVadConfig)
+}
+func (u *AudioTranscriptionNewParamsChunkingStrategyUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *AudioTranscriptionNewParamsChunkingStrategyUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfAudioTranscriptionNewsChunkingStrategyVadConfig) {
+ return u.OfAudioTranscriptionNewsChunkingStrategyVadConfig
+ }
+ return nil
+}
+
+// The property Type is required.
+type AudioTranscriptionNewParamsChunkingStrategyVadConfig struct {
+ // Must be set to `server_vad` to enable manual chunking using server side VAD.
+ //
+ // Any of "server_vad".
+ Type string `json:"type,omitzero,required"`
+ // Amount of audio to include before the VAD detected speech (in milliseconds).
+ PrefixPaddingMs param.Opt[int64] `json:"prefix_padding_ms,omitzero"`
+ // Duration of silence to detect speech stop (in milliseconds). With shorter values
+ // the model will respond more quickly, but may jump in on short pauses from the
+ // user.
+ SilenceDurationMs param.Opt[int64] `json:"silence_duration_ms,omitzero"`
+ // Sensitivity threshold (0.0 to 1.0) for voice activity detection. A higher
+ // threshold will require louder audio to activate the model, and thus might
+ // perform better in noisy environments.
+ Threshold param.Opt[float64] `json:"threshold,omitzero"`
+ paramObj
+}
+
+func (r AudioTranscriptionNewParamsChunkingStrategyVadConfig) MarshalJSON() (data []byte, err error) {
+ type shadow AudioTranscriptionNewParamsChunkingStrategyVadConfig
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *AudioTranscriptionNewParamsChunkingStrategyVadConfig) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+func init() {
+ apijson.RegisterFieldValidator[AudioTranscriptionNewParamsChunkingStrategyVadConfig](
+ "type", "server_vad",
+ )
+}
@@ -14,7 +14,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
)
// AudioTranslationService contains methods and other services that help with
@@ -46,11 +46,10 @@ func (r *AudioTranslationService) New(ctx context.Context, body AudioTranslation
type Translation struct {
Text string `json:"text,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -64,7 +63,7 @@ func (r *Translation) UnmarshalJSON(data []byte) error {
type AudioTranslationNewParams struct {
// The audio file object (not file name) translate, in one of these formats: flac,
// mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
- File io.Reader `json:"file,required" format:"binary"`
+ File io.Reader `json:"file,omitzero,required" format:"binary"`
// ID of the model to use. Only `whisper-1` (which is powered by our open source
// Whisper V2 model) is currently available.
Model AudioModel `json:"model,omitzero,required"`
@@ -87,14 +86,13 @@ type AudioTranslationNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f AudioTranslationNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r AudioTranslationNewParams) MarshalMultipart() (data []byte, contentType string, err error) {
buf := bytes.NewBuffer(nil)
writer := multipart.NewWriter(buf)
err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
if err != nil {
writer.Close()
return nil, "", err
@@ -15,7 +15,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
)
@@ -143,30 +143,29 @@ type Batch struct {
OutputFileID string `json:"output_file_id"`
// The request counts for different statuses within the batch.
RequestCounts BatchRequestCounts `json:"request_counts"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CompletionWindow resp.Field
- CreatedAt resp.Field
- Endpoint resp.Field
- InputFileID resp.Field
- Object resp.Field
- Status resp.Field
- CancelledAt resp.Field
- CancellingAt resp.Field
- CompletedAt resp.Field
- ErrorFileID resp.Field
- Errors resp.Field
- ExpiredAt resp.Field
- ExpiresAt resp.Field
- FailedAt resp.Field
- FinalizingAt resp.Field
- InProgressAt resp.Field
- Metadata resp.Field
- OutputFileID resp.Field
- RequestCounts resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CompletionWindow respjson.Field
+ CreatedAt respjson.Field
+ Endpoint respjson.Field
+ InputFileID respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ CancelledAt respjson.Field
+ CancellingAt respjson.Field
+ CompletedAt respjson.Field
+ ErrorFileID respjson.Field
+ Errors respjson.Field
+ ExpiredAt respjson.Field
+ ExpiresAt respjson.Field
+ FailedAt respjson.Field
+ FinalizingAt respjson.Field
+ InProgressAt respjson.Field
+ Metadata respjson.Field
+ OutputFileID respjson.Field
+ RequestCounts respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -195,12 +194,11 @@ type BatchErrors struct {
Data []BatchError `json:"data"`
// The object type, which is always `list`.
Object string `json:"object"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -220,14 +218,13 @@ type BatchError struct {
Message string `json:"message"`
// The name of the parameter that caused the error, if applicable.
Param string `json:"param,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Line resp.Field
- Message resp.Field
- Param resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Line respjson.Field
+ Message respjson.Field
+ Param respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -246,13 +243,12 @@ type BatchRequestCounts struct {
Failed int64 `json:"failed,required"`
// Total number of requests in the batch.
Total int64 `json:"total,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Completed resp.Field
- Failed resp.Field
- Total resp.Field
- ExtraFields map[string]resp.Field
+ Completed respjson.Field
+ Failed respjson.Field
+ Total respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -293,18 +289,17 @@ type BatchNewParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BatchNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r BatchNewParams) MarshalJSON() (data []byte, err error) {
type shadow BatchNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BatchNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The time frame within which the batch should be processed. Currently only `24h`
// is supported.
@@ -339,12 +334,8 @@ type BatchListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BatchListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [BatchListParams]'s query parameters as `url.Values`.
-func (r BatchListParams) URLQuery() (v url.Values) {
+func (r BatchListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -15,7 +15,8 @@ import (
type BetaService struct {
Options []option.RequestOption
Assistants BetaAssistantService
- Threads BetaThreadService
+ // Deprecated: The Assistants API is deprecated in favor of the Responses API
+ Threads BetaThreadService
}
// NewBetaService generates a new service that applies the given options to each
@@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/url"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/apiquery"
@@ -17,10 +16,9 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// BetaAssistantService contains methods and other services that help with
@@ -181,23 +179,22 @@ type Assistant struct {
//
// We generally recommend altering this or temperature but not both.
TopP float64 `json:"top_p,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Description resp.Field
- Instructions resp.Field
- Metadata resp.Field
- Model resp.Field
- Name resp.Field
- Object resp.Field
- Tools resp.Field
- ResponseFormat resp.Field
- Temperature resp.Field
- ToolResources resp.Field
- TopP resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Description respjson.Field
+ Instructions respjson.Field
+ Metadata respjson.Field
+ Model respjson.Field
+ Name respjson.Field
+ Object respjson.Field
+ Tools respjson.Field
+ ResponseFormat respjson.Field
+ Temperature respjson.Field
+ ToolResources respjson.Field
+ TopP respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -215,12 +212,11 @@ func (r *Assistant) UnmarshalJSON(data []byte) error {
type AssistantToolResources struct {
CodeInterpreter AssistantToolResourcesCodeInterpreter `json:"code_interpreter"`
FileSearch AssistantToolResourcesFileSearch `json:"file_search"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- CodeInterpreter resp.Field
- FileSearch resp.Field
- ExtraFields map[string]resp.Field
+ CodeInterpreter respjson.Field
+ FileSearch respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -236,11 +232,10 @@ type AssistantToolResourcesCodeInterpreter struct {
// available to the `code_interpreterโ tool. There can be a maximum of 20 files
// associated with the tool.
FileIDs []string `json:"file_ids"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileIDs resp.Field
- ExtraFields map[string]resp.Field
+ FileIDs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -257,11 +252,10 @@ type AssistantToolResourcesFileSearch struct {
// attached to this assistant. There can be a maximum of 1 vector store attached to
// the assistant.
VectorStoreIDs []string `json:"vector_store_ids"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- VectorStoreIDs resp.Field
- ExtraFields map[string]resp.Field
+ VectorStoreIDs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -276,13 +270,12 @@ type AssistantDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object constant.AssistantDeleted `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -337,44 +330,76 @@ type AssistantStreamEventUnion struct {
// This field is from variant [AssistantStreamEventThreadCreated].
Enabled bool `json:"enabled"`
JSON struct {
- Data resp.Field
- Event resp.Field
- Enabled resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ Enabled respjson.Field
raw string
} `json:"-"`
}
+// anyAssistantStreamEvent is implemented by each variant of
+// [AssistantStreamEventUnion] to add type safety for the return type of
+// [AssistantStreamEventUnion.AsAny]
+type anyAssistantStreamEvent interface {
+ implAssistantStreamEventUnion()
+}
+
+func (AssistantStreamEventThreadCreated) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunCreated) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunQueued) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunInProgress) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunRequiresAction) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunCompleted) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunIncomplete) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunFailed) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunCancelling) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunCancelled) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunExpired) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepCreated) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepInProgress) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepDelta) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepCompleted) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepFailed) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepCancelled) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadRunStepExpired) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadMessageCreated) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadMessageInProgress) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadMessageDelta) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadMessageCompleted) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventThreadMessageIncomplete) implAssistantStreamEventUnion() {}
+func (AssistantStreamEventErrorEvent) implAssistantStreamEventUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := AssistantStreamEventUnion.AsAny().(type) {
-// case AssistantStreamEventThreadCreated:
-// case AssistantStreamEventThreadRunCreated:
-// case AssistantStreamEventThreadRunQueued:
-// case AssistantStreamEventThreadRunInProgress:
-// case AssistantStreamEventThreadRunRequiresAction:
-// case AssistantStreamEventThreadRunCompleted:
-// case AssistantStreamEventThreadRunIncomplete:
-// case AssistantStreamEventThreadRunFailed:
-// case AssistantStreamEventThreadRunCancelling:
-// case AssistantStreamEventThreadRunCancelled:
-// case AssistantStreamEventThreadRunExpired:
-// case AssistantStreamEventThreadRunStepCreated:
-// case AssistantStreamEventThreadRunStepInProgress:
-// case AssistantStreamEventThreadRunStepDelta:
-// case AssistantStreamEventThreadRunStepCompleted:
-// case AssistantStreamEventThreadRunStepFailed:
-// case AssistantStreamEventThreadRunStepCancelled:
-// case AssistantStreamEventThreadRunStepExpired:
-// case AssistantStreamEventThreadMessageCreated:
-// case AssistantStreamEventThreadMessageInProgress:
-// case AssistantStreamEventThreadMessageDelta:
-// case AssistantStreamEventThreadMessageCompleted:
-// case AssistantStreamEventThreadMessageIncomplete:
-// case AssistantStreamEventErrorEvent:
+// case openai.AssistantStreamEventThreadCreated:
+// case openai.AssistantStreamEventThreadRunCreated:
+// case openai.AssistantStreamEventThreadRunQueued:
+// case openai.AssistantStreamEventThreadRunInProgress:
+// case openai.AssistantStreamEventThreadRunRequiresAction:
+// case openai.AssistantStreamEventThreadRunCompleted:
+// case openai.AssistantStreamEventThreadRunIncomplete:
+// case openai.AssistantStreamEventThreadRunFailed:
+// case openai.AssistantStreamEventThreadRunCancelling:
+// case openai.AssistantStreamEventThreadRunCancelled:
+// case openai.AssistantStreamEventThreadRunExpired:
+// case openai.AssistantStreamEventThreadRunStepCreated:
+// case openai.AssistantStreamEventThreadRunStepInProgress:
+// case openai.AssistantStreamEventThreadRunStepDelta:
+// case openai.AssistantStreamEventThreadRunStepCompleted:
+// case openai.AssistantStreamEventThreadRunStepFailed:
+// case openai.AssistantStreamEventThreadRunStepCancelled:
+// case openai.AssistantStreamEventThreadRunStepExpired:
+// case openai.AssistantStreamEventThreadMessageCreated:
+// case openai.AssistantStreamEventThreadMessageInProgress:
+// case openai.AssistantStreamEventThreadMessageDelta:
+// case openai.AssistantStreamEventThreadMessageCompleted:
+// case openai.AssistantStreamEventThreadMessageIncomplete:
+// case openai.AssistantStreamEventErrorEvent:
// default:
// fmt.Errorf("no variant present")
// }
-func (u AssistantStreamEventUnion) AsAny() any {
+func (u AssistantStreamEventUnion) AsAny() anyAssistantStreamEvent {
switch u.Event {
case "thread.created":
return u.AsThreadCreated()
@@ -632,46 +657,46 @@ type AssistantStreamEventUnionData struct {
// This field is from variant [shared.ErrorObject].
Param string `json:"param"`
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Metadata resp.Field
- Object resp.Field
- ToolResources resp.Field
- AssistantID resp.Field
- CancelledAt resp.Field
- CompletedAt resp.Field
- ExpiresAt resp.Field
- FailedAt resp.Field
- IncompleteDetails resp.Field
- Instructions resp.Field
- LastError resp.Field
- MaxCompletionTokens resp.Field
- MaxPromptTokens resp.Field
- Model resp.Field
- ParallelToolCalls resp.Field
- RequiredAction resp.Field
- ResponseFormat resp.Field
- StartedAt resp.Field
- Status resp.Field
- ThreadID resp.Field
- ToolChoice resp.Field
- Tools resp.Field
- TruncationStrategy resp.Field
- Usage resp.Field
- Temperature resp.Field
- TopP resp.Field
- ExpiredAt resp.Field
- RunID resp.Field
- StepDetails resp.Field
- Type resp.Field
- Delta resp.Field
- Attachments resp.Field
- Content resp.Field
- IncompleteAt resp.Field
- Role resp.Field
- Code resp.Field
- Message resp.Field
- Param resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Metadata respjson.Field
+ Object respjson.Field
+ ToolResources respjson.Field
+ AssistantID respjson.Field
+ CancelledAt respjson.Field
+ CompletedAt respjson.Field
+ ExpiresAt respjson.Field
+ FailedAt respjson.Field
+ IncompleteDetails respjson.Field
+ Instructions respjson.Field
+ LastError respjson.Field
+ MaxCompletionTokens respjson.Field
+ MaxPromptTokens respjson.Field
+ Model respjson.Field
+ ParallelToolCalls respjson.Field
+ RequiredAction respjson.Field
+ ResponseFormat respjson.Field
+ StartedAt respjson.Field
+ Status respjson.Field
+ ThreadID respjson.Field
+ ToolChoice respjson.Field
+ Tools respjson.Field
+ TruncationStrategy respjson.Field
+ Usage respjson.Field
+ Temperature respjson.Field
+ TopP respjson.Field
+ ExpiredAt respjson.Field
+ RunID respjson.Field
+ StepDetails respjson.Field
+ Type respjson.Field
+ Delta respjson.Field
+ Attachments respjson.Field
+ Content respjson.Field
+ IncompleteAt respjson.Field
+ Role respjson.Field
+ Code respjson.Field
+ Message respjson.Field
+ Param respjson.Field
raw string
} `json:"-"`
}
@@ -689,7 +714,7 @@ func (r *AssistantStreamEventUnionData) UnmarshalJSON(data []byte) error {
type AssistantStreamEventUnionDataIncompleteDetails struct {
Reason string `json:"reason"`
JSON struct {
- Reason resp.Field
+ Reason respjson.Field
raw string
} `json:"-"`
}
@@ -708,8 +733,8 @@ type AssistantStreamEventUnionDataLastError struct {
Code string `json:"code"`
Message string `json:"message"`
JSON struct {
- Code resp.Field
- Message resp.Field
+ Code respjson.Field
+ Message respjson.Field
raw string
} `json:"-"`
}
@@ -729,9 +754,9 @@ type AssistantStreamEventUnionDataUsage struct {
PromptTokens int64 `json:"prompt_tokens"`
TotalTokens int64 `json:"total_tokens"`
JSON struct {
- CompletionTokens resp.Field
- PromptTokens resp.Field
- TotalTokens resp.Field
+ CompletionTokens respjson.Field
+ PromptTokens respjson.Field
+ TotalTokens respjson.Field
raw string
} `json:"-"`
}
@@ -754,9 +779,9 @@ type AssistantStreamEventUnionDataDelta struct {
// This field is from variant [MessageDelta].
Role MessageDeltaRole `json:"role"`
JSON struct {
- StepDetails resp.Field
- Content resp.Field
- Role resp.Field
+ StepDetails respjson.Field
+ Content respjson.Field
+ Role respjson.Field
raw string
} `json:"-"`
}
@@ -775,13 +800,12 @@ type AssistantStreamEventThreadCreated struct {
Event constant.ThreadCreated `json:"event,required"`
// Whether to enable input audio transcription.
Enabled bool `json:"enabled"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- Enabled resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ Enabled respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -799,12 +823,11 @@ type AssistantStreamEventThreadRunCreated struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunCreated `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -822,12 +845,11 @@ type AssistantStreamEventThreadRunQueued struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunQueued `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -845,12 +867,11 @@ type AssistantStreamEventThreadRunInProgress struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunInProgress `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -868,12 +889,11 @@ type AssistantStreamEventThreadRunRequiresAction struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunRequiresAction `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -891,12 +911,11 @@ type AssistantStreamEventThreadRunCompleted struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunCompleted `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -914,12 +933,11 @@ type AssistantStreamEventThreadRunIncomplete struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunIncomplete `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -937,12 +955,11 @@ type AssistantStreamEventThreadRunFailed struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunFailed `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -960,12 +977,11 @@ type AssistantStreamEventThreadRunCancelling struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunCancelling `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -983,12 +999,11 @@ type AssistantStreamEventThreadRunCancelled struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunCancelled `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1006,12 +1021,11 @@ type AssistantStreamEventThreadRunExpired struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Run `json:"data,required"`
Event constant.ThreadRunExpired `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1029,12 +1043,11 @@ type AssistantStreamEventThreadRunStepCreated struct {
// Represents a step in execution of a run.
Data RunStep `json:"data,required"`
Event constant.ThreadRunStepCreated `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1052,12 +1065,11 @@ type AssistantStreamEventThreadRunStepInProgress struct {
// Represents a step in execution of a run.
Data RunStep `json:"data,required"`
Event constant.ThreadRunStepInProgress `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1076,12 +1088,11 @@ type AssistantStreamEventThreadRunStepDelta struct {
// streaming.
Data RunStepDeltaEvent `json:"data,required"`
Event constant.ThreadRunStepDelta `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1099,12 +1110,11 @@ type AssistantStreamEventThreadRunStepCompleted struct {
// Represents a step in execution of a run.
Data RunStep `json:"data,required"`
Event constant.ThreadRunStepCompleted `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1122,12 +1132,11 @@ type AssistantStreamEventThreadRunStepFailed struct {
// Represents a step in execution of a run.
Data RunStep `json:"data,required"`
Event constant.ThreadRunStepFailed `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1145,12 +1154,11 @@ type AssistantStreamEventThreadRunStepCancelled struct {
// Represents a step in execution of a run.
Data RunStep `json:"data,required"`
Event constant.ThreadRunStepCancelled `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1168,12 +1176,11 @@ type AssistantStreamEventThreadRunStepExpired struct {
// Represents a step in execution of a run.
Data RunStep `json:"data,required"`
Event constant.ThreadRunStepExpired `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1192,12 +1199,11 @@ type AssistantStreamEventThreadMessageCreated struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Message `json:"data,required"`
Event constant.ThreadMessageCreated `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1216,12 +1222,11 @@ type AssistantStreamEventThreadMessageInProgress struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Message `json:"data,required"`
Event constant.ThreadMessageInProgress `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1240,12 +1245,11 @@ type AssistantStreamEventThreadMessageDelta struct {
// streaming.
Data MessageDeltaEvent `json:"data,required"`
Event constant.ThreadMessageDelta `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1264,12 +1268,11 @@ type AssistantStreamEventThreadMessageCompleted struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Message `json:"data,required"`
Event constant.ThreadMessageCompleted `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1288,12 +1291,11 @@ type AssistantStreamEventThreadMessageIncomplete struct {
// [thread](https://platform.openai.com/docs/api-reference/threads).
Data Message `json:"data,required"`
Event constant.ThreadMessageIncomplete `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1310,12 +1312,11 @@ func (r *AssistantStreamEventThreadMessageIncomplete) UnmarshalJSON(data []byte)
type AssistantStreamEventErrorEvent struct {
Data shared.ErrorObject `json:"data,required"`
Event constant.Error `json:"event,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Event resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Event respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1340,23 +1341,33 @@ type AssistantToolUnion struct {
// This field is from variant [FunctionTool].
Function shared.FunctionDefinition `json:"function"`
JSON struct {
- Type resp.Field
- FileSearch resp.Field
- Function resp.Field
+ Type respjson.Field
+ FileSearch respjson.Field
+ Function respjson.Field
raw string
} `json:"-"`
}
+// anyAssistantTool is implemented by each variant of [AssistantToolUnion] to add
+// type safety for the return type of [AssistantToolUnion.AsAny]
+type anyAssistantTool interface {
+ implAssistantToolUnion()
+}
+
+func (CodeInterpreterTool) implAssistantToolUnion() {}
+func (FileSearchTool) implAssistantToolUnion() {}
+func (FunctionTool) implAssistantToolUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := AssistantToolUnion.AsAny().(type) {
-// case CodeInterpreterTool:
-// case FileSearchTool:
-// case FunctionTool:
+// case openai.CodeInterpreterTool:
+// case openai.FileSearchTool:
+// case openai.FunctionTool:
// default:
// fmt.Errorf("no variant present")
// }
-func (u AssistantToolUnion) AsAny() any {
+func (u AssistantToolUnion) AsAny() anyAssistantTool {
switch u.Type {
case "code_interpreter":
return u.AsCodeInterpreter()
@@ -1394,9 +1405,9 @@ func (r *AssistantToolUnion) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// AssistantToolUnionParam.IsOverridden()
+// AssistantToolUnionParam.Overrides()
func (r AssistantToolUnion) ToParam() AssistantToolUnionParam {
- return param.OverrideObj[AssistantToolUnionParam](r.RawJSON())
+ return param.Override[AssistantToolUnionParam](json.RawMessage(r.RawJSON()))
}
func AssistantToolParamOfFunction(function shared.FunctionDefinitionParam) AssistantToolUnionParam {
@@ -1415,11 +1426,11 @@ type AssistantToolUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u AssistantToolUnionParam) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u AssistantToolUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[AssistantToolUnionParam](u.OfCodeInterpreter, u.OfFileSearch, u.OfFunction)
+ return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch, u.OfFunction)
+}
+func (u *AssistantToolUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *AssistantToolUnionParam) asAny() any {
@@ -1464,32 +1475,19 @@ func (u AssistantToolUnionParam) GetType() *string {
func init() {
apijson.RegisterUnion[AssistantToolUnionParam](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(CodeInterpreterToolParam{}),
- DiscriminatorValue: "code_interpreter",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(FileSearchToolParam{}),
- DiscriminatorValue: "file_search",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(FunctionToolParam{}),
- DiscriminatorValue: "function",
- },
+ apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"),
+ apijson.Discriminator[FileSearchToolParam]("file_search"),
+ apijson.Discriminator[FunctionToolParam]("function"),
)
}
type CodeInterpreterTool struct {
// The type of tool being defined: `code_interpreter`
Type constant.CodeInterpreter `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1504,39 +1502,43 @@ func (r *CodeInterpreterTool) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// CodeInterpreterToolParam.IsOverridden()
+// CodeInterpreterToolParam.Overrides()
func (r CodeInterpreterTool) ToParam() CodeInterpreterToolParam {
- return param.OverrideObj[CodeInterpreterToolParam](r.RawJSON())
+ return param.Override[CodeInterpreterToolParam](json.RawMessage(r.RawJSON()))
}
-// The property Type is required.
+func NewCodeInterpreterToolParam() CodeInterpreterToolParam {
+ return CodeInterpreterToolParam{
+ Type: "code_interpreter",
+ }
+}
+
+// This struct has a constant value, construct it with
+// [NewCodeInterpreterToolParam].
type CodeInterpreterToolParam struct {
// The type of tool being defined: `code_interpreter`
- //
- // This field can be elided, and will marshal its zero value as "code_interpreter".
Type constant.CodeInterpreter `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f CodeInterpreterToolParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r CodeInterpreterToolParam) MarshalJSON() (data []byte, err error) {
type shadow CodeInterpreterToolParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *CodeInterpreterToolParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type FileSearchTool struct {
// The type of tool being defined: `file_search`
Type constant.FileSearch `json:"type,required"`
// Overrides for the file search tool.
FileSearch FileSearchToolFileSearch `json:"file_search"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- FileSearch resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ FileSearch respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1551,9 +1553,9 @@ func (r *FileSearchTool) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// FileSearchToolParam.IsOverridden()
+// FileSearchToolParam.Overrides()
func (r FileSearchTool) ToParam() FileSearchToolParam {
- return param.OverrideObj[FileSearchToolParam](r.RawJSON())
+ return param.Override[FileSearchToolParam](json.RawMessage(r.RawJSON()))
}
// Overrides for the file search tool.
@@ -1574,12 +1576,11 @@ type FileSearchToolFileSearch struct {
// [file search tool documentation](https://platform.openai.com/docs/assistants/tools/file-search#customizing-file-search-settings)
// for more information.
RankingOptions FileSearchToolFileSearchRankingOptions `json:"ranking_options"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- MaxNumResults resp.Field
- RankingOptions resp.Field
- ExtraFields map[string]resp.Field
+ MaxNumResults respjson.Field
+ RankingOptions respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1605,12 +1606,11 @@ type FileSearchToolFileSearchRankingOptions struct {
//
// Any of "auto", "default_2024_08_21".
Ranker string `json:"ranker"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ScoreThreshold resp.Field
- Ranker resp.Field
- ExtraFields map[string]resp.Field
+ ScoreThreshold respjson.Field
+ Ranker respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1632,13 +1632,13 @@ type FileSearchToolParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FileSearchToolParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r FileSearchToolParam) MarshalJSON() (data []byte, err error) {
type shadow FileSearchToolParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FileSearchToolParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Overrides for the file search tool.
type FileSearchToolFileSearchParam struct {
@@ -8,17 +8,15 @@ import (
"errors"
"fmt"
"net/http"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/packages/ssestream"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// BetaThreadService contains methods and other services that help with interacting
@@ -27,9 +25,13 @@ import (
// Note, unlike clients, this service does not read variables from the environment
// automatically. You should not instantiate this service directly, and instead use
// the [NewBetaThreadService] method instead.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
type BetaThreadService struct {
- Options []option.RequestOption
- Runs BetaThreadRunService
+ Options []option.RequestOption
+ // Deprecated: The Assistants API is deprecated in favor of the Responses API
+ Runs BetaThreadRunService
+ // Deprecated: The Assistants API is deprecated in favor of the Responses API
Messages BetaThreadMessageService
}
@@ -45,6 +47,8 @@ func NewBetaThreadService(opts ...option.RequestOption) (r BetaThreadService) {
}
// Create a thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadService) New(ctx context.Context, body BetaThreadNewParams, opts ...option.RequestOption) (res *Thread, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -53,16 +57,9 @@ func (r *BetaThreadService) New(ctx context.Context, body BetaThreadNewParams, o
return
}
-// Create a thread and run it in one request. Poll the API until the run is complete.
-func (r *BetaThreadService) NewAndRunPoll(ctx context.Context, body BetaThreadNewAndRunParams, pollIntervalMs int, opts ...option.RequestOption) (res *Run, err error) {
- run, err := r.NewAndRun(ctx, body, opts...)
- if err != nil {
- return nil, err
- }
- return r.Runs.PollStatus(ctx, run.ThreadID, run.ID, pollIntervalMs, opts...)
-}
-
// Retrieves a thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadService) Get(ctx context.Context, threadID string, opts ...option.RequestOption) (res *Thread, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -76,6 +73,8 @@ func (r *BetaThreadService) Get(ctx context.Context, threadID string, opts ...op
}
// Modifies a thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadService) Update(ctx context.Context, threadID string, body BetaThreadUpdateParams, opts ...option.RequestOption) (res *Thread, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -89,6 +88,8 @@ func (r *BetaThreadService) Update(ctx context.Context, threadID string, body Be
}
// Delete a thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadService) Delete(ctx context.Context, threadID string, opts ...option.RequestOption) (res *ThreadDeleted, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -102,6 +103,8 @@ func (r *BetaThreadService) Delete(ctx context.Context, threadID string, opts ..
}
// Create a thread and run it in one request.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadService) NewAndRun(ctx context.Context, body BetaThreadNewAndRunParams, opts ...option.RequestOption) (res *Run, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -111,6 +114,8 @@ func (r *BetaThreadService) NewAndRun(ctx context.Context, body BetaThreadNewAnd
}
// Create a thread and run it in one request.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadService) NewAndRunStreaming(ctx context.Context, body BetaThreadNewAndRunParams, opts ...option.RequestOption) (stream *ssestream.Stream[AssistantStreamEventUnion]) {
var (
raw *http.Response
@@ -139,9 +144,9 @@ type AssistantResponseFormatOptionUnion struct {
// This field is from variant [shared.ResponseFormatJSONSchema].
JSONSchema shared.ResponseFormatJSONSchemaJSONSchema `json:"json_schema"`
JSON struct {
- OfAuto resp.Field
- Type resp.Field
- JSONSchema resp.Field
+ OfAuto respjson.Field
+ Type respjson.Field
+ JSONSchema respjson.Field
raw string
} `json:"-"`
}
@@ -178,9 +183,9 @@ func (r *AssistantResponseFormatOptionUnion) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// AssistantResponseFormatOptionUnionParam.IsOverridden()
+// AssistantResponseFormatOptionUnionParam.Overrides()
func (r AssistantResponseFormatOptionUnion) ToParam() AssistantResponseFormatOptionUnionParam {
- return param.OverrideObj[AssistantResponseFormatOptionUnionParam](r.RawJSON())
+ return param.Override[AssistantResponseFormatOptionUnionParam](json.RawMessage(r.RawJSON()))
}
func AssistantResponseFormatOptionParamOfAuto() AssistantResponseFormatOptionUnionParam {
@@ -197,8 +202,7 @@ func AssistantResponseFormatOptionParamOfJSONSchema(jsonSchema shared.ResponseFo
//
// Use [param.IsOmitted] to confirm if a field is set.
type AssistantResponseFormatOptionUnionParam struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
+ // Construct this variant with constant.ValueOf[constant.Auto]()
OfAuto constant.Auto `json:",omitzero,inline"`
OfText *shared.ResponseFormatTextParam `json:",omitzero,inline"`
OfJSONObject *shared.ResponseFormatJSONObjectParam `json:",omitzero,inline"`
@@ -206,13 +210,11 @@ type AssistantResponseFormatOptionUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u AssistantResponseFormatOptionUnionParam) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u AssistantResponseFormatOptionUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[AssistantResponseFormatOptionUnionParam](u.OfAuto, u.OfText, u.OfJSONObject, u.OfJSONSchema)
+ return param.MarshalUnion(u, u.OfAuto, u.OfText, u.OfJSONObject, u.OfJSONSchema)
+}
+func (u *AssistantResponseFormatOptionUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *AssistantResponseFormatOptionUnionParam) asAny() any {
@@ -256,12 +258,11 @@ type AssistantToolChoice struct {
// Any of "function", "code_interpreter", "file_search".
Type AssistantToolChoiceType `json:"type,required"`
Function AssistantToolChoiceFunction `json:"function"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- Function resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ Function respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -276,9 +277,9 @@ func (r *AssistantToolChoice) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// AssistantToolChoiceParam.IsOverridden()
+// AssistantToolChoiceParam.Overrides()
func (r AssistantToolChoice) ToParam() AssistantToolChoiceParam {
- return param.OverrideObj[AssistantToolChoiceParam](r.RawJSON())
+ return param.Override[AssistantToolChoiceParam](json.RawMessage(r.RawJSON()))
}
// The type of the tool. If type is `function`, the function name must be set
@@ -303,22 +304,21 @@ type AssistantToolChoiceParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f AssistantToolChoiceParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r AssistantToolChoiceParam) MarshalJSON() (data []byte, err error) {
type shadow AssistantToolChoiceParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *AssistantToolChoiceParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type AssistantToolChoiceFunction struct {
// The name of the function to call.
Name string `json:"name,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Name resp.Field
- ExtraFields map[string]resp.Field
+ Name respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -334,9 +334,9 @@ func (r *AssistantToolChoiceFunction) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// AssistantToolChoiceFunctionParam.IsOverridden()
+// AssistantToolChoiceFunctionParam.Overrides()
func (r AssistantToolChoiceFunction) ToParam() AssistantToolChoiceFunctionParam {
- return param.OverrideObj[AssistantToolChoiceFunctionParam](r.RawJSON())
+ return param.Override[AssistantToolChoiceFunctionParam](json.RawMessage(r.RawJSON()))
}
// The property Name is required.
@@ -346,13 +346,13 @@ type AssistantToolChoiceFunctionParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f AssistantToolChoiceFunctionParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r AssistantToolChoiceFunctionParam) MarshalJSON() (data []byte, err error) {
type shadow AssistantToolChoiceFunctionParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *AssistantToolChoiceFunctionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// AssistantToolChoiceOptionUnion contains all possible properties and values from
// [string], [AssistantToolChoice].
@@ -369,9 +369,9 @@ type AssistantToolChoiceOptionUnion struct {
// This field is from variant [AssistantToolChoice].
Function AssistantToolChoiceFunction `json:"function"`
JSON struct {
- OfAuto resp.Field
- Type resp.Field
- Function resp.Field
+ OfAuto respjson.Field
+ Type respjson.Field
+ Function respjson.Field
raw string
} `json:"-"`
}
@@ -398,11 +398,23 @@ func (r *AssistantToolChoiceOptionUnion) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// AssistantToolChoiceOptionUnionParam.IsOverridden()
+// AssistantToolChoiceOptionUnionParam.Overrides()
func (r AssistantToolChoiceOptionUnion) ToParam() AssistantToolChoiceOptionUnionParam {
- return param.OverrideObj[AssistantToolChoiceOptionUnionParam](r.RawJSON())
+ return param.Override[AssistantToolChoiceOptionUnionParam](json.RawMessage(r.RawJSON()))
}
+// `none` means the model will not call any tools and instead generates a message.
+// `auto` means the model can pick between generating a message or calling one or
+// more tools. `required` means the model must call one or more tools before
+// responding to the user.
+type AssistantToolChoiceOptionAuto string
+
+const (
+ AssistantToolChoiceOptionAutoNone AssistantToolChoiceOptionAuto = "none"
+ AssistantToolChoiceOptionAutoAuto AssistantToolChoiceOptionAuto = "auto"
+ AssistantToolChoiceOptionAutoRequired AssistantToolChoiceOptionAuto = "required"
+)
+
func AssistantToolChoiceOptionParamOfAssistantToolChoice(type_ AssistantToolChoiceType) AssistantToolChoiceOptionUnionParam {
var variant AssistantToolChoiceParam
variant.Type = type_
@@ -419,13 +431,11 @@ type AssistantToolChoiceOptionUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u AssistantToolChoiceOptionUnionParam) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u AssistantToolChoiceOptionUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[AssistantToolChoiceOptionUnionParam](u.OfAuto, u.OfAssistantToolChoice)
+ return param.MarshalUnion(u, u.OfAuto, u.OfAssistantToolChoice)
+}
+func (u *AssistantToolChoiceOptionUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *AssistantToolChoiceOptionUnionParam) asAny() any {
@@ -437,22 +447,6 @@ func (u *AssistantToolChoiceOptionUnionParam) asAny() any {
return nil
}
-// Returns a pointer to the underlying variant's property, if present.
-func (u AssistantToolChoiceOptionUnionParam) GetType() *string {
- if vt := u.OfAssistantToolChoice; vt != nil {
- return (*string)(&vt.Type)
- }
- return nil
-}
-
-// Returns a pointer to the underlying variant's property, if present.
-func (u AssistantToolChoiceOptionUnionParam) GetFunction() *AssistantToolChoiceFunctionParam {
- if vt := u.OfAssistantToolChoice; vt != nil {
- return &vt.Function
- }
- return nil
-}
-
// Represents a thread that contains
// [messages](https://platform.openai.com/docs/api-reference/messages).
type Thread struct {
@@ -474,15 +468,14 @@ type Thread struct {
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
// tool requires a list of vector store IDs.
ToolResources ThreadToolResources `json:"tool_resources,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Metadata resp.Field
- Object resp.Field
- ToolResources resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Metadata respjson.Field
+ Object respjson.Field
+ ToolResources respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -500,12 +493,11 @@ func (r *Thread) UnmarshalJSON(data []byte) error {
type ThreadToolResources struct {
CodeInterpreter ThreadToolResourcesCodeInterpreter `json:"code_interpreter"`
FileSearch ThreadToolResourcesFileSearch `json:"file_search"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- CodeInterpreter resp.Field
- FileSearch resp.Field
- ExtraFields map[string]resp.Field
+ CodeInterpreter respjson.Field
+ FileSearch respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -521,11 +513,10 @@ type ThreadToolResourcesCodeInterpreter struct {
// available to the `code_interpreter` tool. There can be a maximum of 20 files
// associated with the tool.
FileIDs []string `json:"file_ids"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileIDs resp.Field
- ExtraFields map[string]resp.Field
+ FileIDs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -542,11 +533,10 @@ type ThreadToolResourcesFileSearch struct {
// attached to this thread. There can be a maximum of 1 vector store attached to
// the thread.
VectorStoreIDs []string `json:"vector_store_ids"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- VectorStoreIDs resp.Field
- ExtraFields map[string]resp.Field
+ VectorStoreIDs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -561,13 +551,12 @@ type ThreadDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object constant.ThreadDeleted `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -585,7 +574,7 @@ type BetaThreadNewParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
@@ -597,14 +586,13 @@ type BetaThreadNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r BetaThreadNewParams) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The properties Content, Role are required.
type BetaThreadNewParamsMessage struct {
@@ -627,21 +615,21 @@ type BetaThreadNewParamsMessage struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsMessage) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r BetaThreadNewParamsMessage) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsMessage
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsMessage) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[BetaThreadNewParamsMessage](
- "Role", false, "user", "assistant",
+ "role", "user", "assistant",
)
}
@@ -654,13 +642,11 @@ type BetaThreadNewParamsMessageContentUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadNewParamsMessageContentUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadNewParamsMessageContentUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadNewParamsMessageContentUnion](u.OfString, u.OfArrayOfContentParts)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
+}
+func (u *BetaThreadNewParamsMessageContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadNewParamsMessageContentUnion) asAny() any {
@@ -680,15 +666,13 @@ type BetaThreadNewParamsMessageAttachment struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsMessageAttachment) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsMessageAttachment) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsMessageAttachment
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsMessageAttachment) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -699,13 +683,11 @@ type BetaThreadNewParamsMessageAttachmentToolUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadNewParamsMessageAttachmentToolUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadNewParamsMessageAttachmentToolUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadNewParamsMessageAttachmentToolUnion](u.OfCodeInterpreter, u.OfFileSearch)
+ return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch)
+}
+func (u *BetaThreadNewParamsMessageAttachmentToolUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadNewParamsMessageAttachmentToolUnion) asAny() any {
@@ -730,37 +712,32 @@ func (u BetaThreadNewParamsMessageAttachmentToolUnion) GetType() *string {
func init() {
apijson.RegisterUnion[BetaThreadNewParamsMessageAttachmentToolUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(CodeInterpreterToolParam{}),
- DiscriminatorValue: "code_interpreter",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(BetaThreadNewParamsMessageAttachmentToolFileSearch{}),
- DiscriminatorValue: "file_search",
- },
+ apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"),
+ apijson.Discriminator[BetaThreadNewParamsMessageAttachmentToolFileSearch]("file_search"),
)
}
-// The property Type is required.
+func NewBetaThreadNewParamsMessageAttachmentToolFileSearch() BetaThreadNewParamsMessageAttachmentToolFileSearch {
+ return BetaThreadNewParamsMessageAttachmentToolFileSearch{
+ Type: "file_search",
+ }
+}
+
+// This struct has a constant value, construct it with
+// [NewBetaThreadNewParamsMessageAttachmentToolFileSearch].
type BetaThreadNewParamsMessageAttachmentToolFileSearch struct {
// The type of tool being defined: `file_search`
- //
- // This field can be elided, and will marshal its zero value as "file_search".
Type constant.FileSearch `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsMessageAttachmentToolFileSearch) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsMessageAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsMessageAttachmentToolFileSearch
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsMessageAttachmentToolFileSearch) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
@@ -772,13 +749,13 @@ type BetaThreadNewParamsToolResources struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResources) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r BetaThreadNewParamsToolResources) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResources
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResources) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadNewParamsToolResourcesCodeInterpreter struct {
// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
@@ -788,15 +765,13 @@ type BetaThreadNewParamsToolResourcesCodeInterpreter struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResourcesCodeInterpreter) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResourcesCodeInterpreter
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadNewParamsToolResourcesFileSearch struct {
// The
@@ -812,15 +787,13 @@ type BetaThreadNewParamsToolResourcesFileSearch struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResourcesFileSearch) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResourcesFileSearch
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadNewParamsToolResourcesFileSearchVectorStore struct {
// Set of 16 key-value pairs that can be attached to an object. This can be useful
@@ -829,7 +802,7 @@ type BetaThreadNewParamsToolResourcesFileSearchVectorStore struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// The chunking strategy used to chunk the file(s). If not set, will use the `auto`
// strategy.
ChunkingStrategy BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion `json:"chunking_strategy,omitzero"`
@@ -840,15 +813,13 @@ type BetaThreadNewParamsToolResourcesFileSearchVectorStore struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResourcesFileSearchVectorStore) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsToolResourcesFileSearchVectorStore) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStore
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStore) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -859,13 +830,11 @@ type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion](u.OfAuto, u.OfStatic)
+ return param.MarshalUnion(u, u.OfAuto, u.OfStatic)
+}
+func (u *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion) asAny() any {
@@ -898,40 +867,35 @@ func (u BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUni
func init() {
apijson.RegisterUnion[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto{}),
- DiscriminatorValue: "auto",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic{}),
- DiscriminatorValue: "static",
- },
+ apijson.Discriminator[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto]("auto"),
+ apijson.Discriminator[BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic]("static"),
)
}
+func NewBetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto() BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto {
+ return BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto{
+ Type: "auto",
+ }
+}
+
// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of
// `800` and `chunk_overlap_tokens` of `400`.
//
-// The property Type is required.
+// This struct has a constant value, construct it with
+// [NewBetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto].
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto struct {
// Always `auto`.
- //
- // This field can be elided, and will marshal its zero value as "auto".
Type constant.Auto `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyAuto) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The properties Static, Type are required.
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic struct {
@@ -943,15 +907,13 @@ type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The properties ChunkOverlapTokens, MaxChunkSizeTokens are required.
type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic struct {
@@ -965,15 +927,13 @@ type BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStatic
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewParamsToolResourcesFileSearchVectorStoreChunkingStrategyStaticStatic) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadUpdateParams struct {
// Set of 16 key-value pairs that can be attached to an object. This can be useful
@@ -982,7 +942,7 @@ type BetaThreadUpdateParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
@@ -991,14 +951,13 @@ type BetaThreadUpdateParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadUpdateParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r BetaThreadUpdateParams) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadUpdateParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadUpdateParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
@@ -1010,15 +969,13 @@ type BetaThreadUpdateParamsToolResources struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadUpdateParamsToolResources) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadUpdateParamsToolResources) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadUpdateParamsToolResources
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadUpdateParamsToolResources) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadUpdateParamsToolResourcesCodeInterpreter struct {
// A list of [file](https://platform.openai.com/docs/api-reference/files) IDs made
@@ -1028,15 +985,13 @@ type BetaThreadUpdateParamsToolResourcesCodeInterpreter struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadUpdateParamsToolResourcesCodeInterpreter) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadUpdateParamsToolResourcesCodeInterpreter) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadUpdateParamsToolResourcesCodeInterpreter
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadUpdateParamsToolResourcesCodeInterpreter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadUpdateParamsToolResourcesFileSearch struct {
// The
@@ -1047,15 +1002,13 @@ type BetaThreadUpdateParamsToolResourcesFileSearch struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadUpdateParamsToolResourcesFileSearch) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadUpdateParamsToolResourcesFileSearch) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadUpdateParamsToolResourcesFileSearch
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadUpdateParamsToolResourcesFileSearch) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadNewAndRunParams struct {
// The ID of the
@@ -1097,7 +1050,7 @@ type BetaThreadNewAndRunParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// The ID of the [Model](https://platform.openai.com/docs/api-reference/models) to
// be used to execute this run. If a value is provided here, it will override the
// model associated with the assistant. If not, the model associated with the
@@ -1110,7 +1063,7 @@ type BetaThreadNewAndRunParams struct {
ToolResources BetaThreadNewAndRunParamsToolResources `json:"tool_resources,omitzero"`
// Override the tools the assistant can use for this run. This is useful for
// modifying the behavior on a per-run basis.
- Tools []BetaThreadNewAndRunParamsToolUnion `json:"tools,omitzero"`
+ Tools []AssistantToolUnionParam `json:"tools,omitzero"`
// Controls for how a thread will be truncated prior to the run. Use this to
// control the intial context window of the run.
TruncationStrategy BetaThreadNewAndRunParamsTruncationStrategy `json:"truncation_strategy,omitzero"`
@@ -1149,14 +1102,13 @@ type BetaThreadNewAndRunParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewAndRunParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r BetaThreadNewAndRunParams) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewAndRunParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewAndRunParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Options to create a new thread. If no thread is provided when running a request,
// an empty thread will be created.
@@ -1167,7 +1119,7 @@ type BetaThreadNewAndRunParamsThread struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
// `code_interpreter` tool requires a list of file IDs, while the `file_search`
@@ -1179,13 +1131,13 @@ type BetaThreadNewAndRunParamsThread struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewAndRunParamsThread) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r BetaThreadNewAndRunParamsThread) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewAndRunParamsThread
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewAndRunParamsThread) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The properties Content, Role are required.
type BetaThreadNewAndRunParamsThreadMessage struct {
@@ -1208,23 +1160,21 @@ type BetaThreadNewAndRunParamsThreadMessage struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewAndRunParamsThreadMessage) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewAndRunParamsThreadMessage) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewAndRunParamsThreadMessage
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewAndRunParamsThreadMessage) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[BetaThreadNewAndRunParamsThreadMessage](
- "Role", false, "user", "assistant",
+ "role", "user", "assistant",
)
}
@@ -1237,13 +1187,11 @@ type BetaThreadNewAndRunParamsThreadMessageContentUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadNewAndRunParamsThreadMessageContentUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadNewAndRunParamsThreadMessageContentUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadNewAndRunParamsThreadMessageContentUnion](u.OfString, u.OfArrayOfContentParts)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
+}
+func (u *BetaThreadNewAndRunParamsThreadMessageContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadNewAndRunParamsThreadMessageContentUnion) asAny() any {
@@ -1263,15 +1211,13 @@ type BetaThreadNewAndRunParamsThreadMessageAttachment struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewAndRunParamsThreadMessageAttachment) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewAndRunParamsThreadMessageAttachment) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewAndRunParamsThreadMessageAttachment
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewAndRunParamsThreadMessageAttachment) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -1282,13 +1228,11 @@ type BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion](u.OfCodeInterpreter, u.OfFileSearch)
+ return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch)
+}
+func (u *BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) asAny() any {
@@ -1313,37 +1257,32 @@ func (u BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion) GetType() *st
func init() {
apijson.RegisterUnion[BetaThreadNewAndRunParamsThreadMessageAttachmentToolUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(CodeInterpreterToolParam{}),
- DiscriminatorValue: "code_interpreter",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch{}),
- DiscriminatorValue: "file_search",
- },
+ apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"),
+ apijson.Discriminator[BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch]("file_search"),
)
}
-// The property Type is required.
+func NewBetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch() BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch {
+ return BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch{
+ Type: "file_search",
+ }
+}
+
+// This struct has a constant value, construct it with
+// [NewBetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch].
type BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch struct {
// The type of tool being defined: `file_search`
- //
- // This field can be elided, and will marshal its zero value as "file_search".
Type constant.FileSearch `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadNewAndRunParamsThreadMessageAttachmentToolFileSearch) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// A set of resources that are made available to the assistant's tools in this
// thread. The resources are specific to the type of tool. For example, the
@@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/url"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/apiquery"
@@ -17,10 +16,9 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// BetaThreadMessageService contains methods and other services that help with
@@ -29,6 +27,8 @@ import (
// Note, unlike clients, this service does not read variables from the environment
// automatically. You should not instantiate this service directly, and instead use
// the [NewBetaThreadMessageService] method instead.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
type BetaThreadMessageService struct {
Options []option.RequestOption
}
@@ -43,6 +43,8 @@ func NewBetaThreadMessageService(opts ...option.RequestOption) (r BetaThreadMess
}
// Create a message.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadMessageService) New(ctx context.Context, threadID string, body BetaThreadMessageNewParams, opts ...option.RequestOption) (res *Message, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -56,6 +58,8 @@ func (r *BetaThreadMessageService) New(ctx context.Context, threadID string, bod
}
// Retrieve a message.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadMessageService) Get(ctx context.Context, threadID string, messageID string, opts ...option.RequestOption) (res *Message, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -73,6 +77,8 @@ func (r *BetaThreadMessageService) Get(ctx context.Context, threadID string, mes
}
// Modifies a message.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadMessageService) Update(ctx context.Context, threadID string, messageID string, body BetaThreadMessageUpdateParams, opts ...option.RequestOption) (res *Message, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -90,6 +96,8 @@ func (r *BetaThreadMessageService) Update(ctx context.Context, threadID string,
}
// Returns a list of messages for a given thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadMessageService) List(ctx context.Context, threadID string, query BetaThreadMessageListParams, opts ...option.RequestOption) (res *pagination.CursorPage[Message], err error) {
var raw *http.Response
opts = append(r.Options[:], opts...)
@@ -112,11 +120,15 @@ func (r *BetaThreadMessageService) List(ctx context.Context, threadID string, qu
}
// Returns a list of messages for a given thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadMessageService) ListAutoPaging(ctx context.Context, threadID string, query BetaThreadMessageListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[Message] {
return pagination.NewCursorPageAutoPager(r.List(ctx, threadID, query, opts...))
}
// Deletes a message.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadMessageService) Delete(ctx context.Context, threadID string, messageID string, opts ...option.RequestOption) (res *MessageDeleted, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -150,25 +162,34 @@ type AnnotationUnion struct {
// This field is from variant [FilePathAnnotation].
FilePath FilePathAnnotationFilePath `json:"file_path"`
JSON struct {
- EndIndex resp.Field
- FileCitation resp.Field
- StartIndex resp.Field
- Text resp.Field
- Type resp.Field
- FilePath resp.Field
+ EndIndex respjson.Field
+ FileCitation respjson.Field
+ StartIndex respjson.Field
+ Text respjson.Field
+ Type respjson.Field
+ FilePath respjson.Field
raw string
} `json:"-"`
}
+// anyAnnotation is implemented by each variant of [AnnotationUnion] to add type
+// safety for the return type of [AnnotationUnion.AsAny]
+type anyAnnotation interface {
+ implAnnotationUnion()
+}
+
+func (FileCitationAnnotation) implAnnotationUnion() {}
+func (FilePathAnnotation) implAnnotationUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := AnnotationUnion.AsAny().(type) {
-// case FileCitationAnnotation:
-// case FilePathAnnotation:
+// case openai.FileCitationAnnotation:
+// case openai.FilePathAnnotation:
// default:
// fmt.Errorf("no variant present")
// }
-func (u AnnotationUnion) AsAny() any {
+func (u AnnotationUnion) AsAny() anyAnnotation {
switch u.Type {
case "file_citation":
return u.AsFileCitation()
@@ -213,26 +234,35 @@ type AnnotationDeltaUnion struct {
// This field is from variant [FilePathDeltaAnnotation].
FilePath FilePathDeltaAnnotationFilePath `json:"file_path"`
JSON struct {
- Index resp.Field
- Type resp.Field
- EndIndex resp.Field
- FileCitation resp.Field
- StartIndex resp.Field
- Text resp.Field
- FilePath resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ EndIndex respjson.Field
+ FileCitation respjson.Field
+ StartIndex respjson.Field
+ Text respjson.Field
+ FilePath respjson.Field
raw string
} `json:"-"`
}
+// anyAnnotationDelta is implemented by each variant of [AnnotationDeltaUnion] to
+// add type safety for the return type of [AnnotationDeltaUnion.AsAny]
+type anyAnnotationDelta interface {
+ implAnnotationDeltaUnion()
+}
+
+func (FileCitationDeltaAnnotation) implAnnotationDeltaUnion() {}
+func (FilePathDeltaAnnotation) implAnnotationDeltaUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := AnnotationDeltaUnion.AsAny().(type) {
-// case FileCitationDeltaAnnotation:
-// case FilePathDeltaAnnotation:
+// case openai.FileCitationDeltaAnnotation:
+// case openai.FilePathDeltaAnnotation:
// default:
// fmt.Errorf("no variant present")
// }
-func (u AnnotationDeltaUnion) AsAny() any {
+func (u AnnotationDeltaUnion) AsAny() anyAnnotationDelta {
switch u.Type {
case "file_citation":
return u.AsFileCitation()
@@ -270,15 +300,14 @@ type FileCitationAnnotation struct {
Text string `json:"text,required"`
// Always `file_citation`.
Type constant.FileCitation `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- EndIndex resp.Field
- FileCitation resp.Field
- StartIndex resp.Field
- Text resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ EndIndex respjson.Field
+ FileCitation respjson.Field
+ StartIndex respjson.Field
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -292,11 +321,10 @@ func (r *FileCitationAnnotation) UnmarshalJSON(data []byte) error {
type FileCitationAnnotationFileCitation struct {
// The ID of the specific File the citation is from.
FileID string `json:"file_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -320,16 +348,15 @@ type FileCitationDeltaAnnotation struct {
StartIndex int64 `json:"start_index"`
// The text in the message content that needs to be replaced.
Text string `json:"text"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- EndIndex resp.Field
- FileCitation resp.Field
- StartIndex resp.Field
- Text resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ EndIndex respjson.Field
+ FileCitation respjson.Field
+ StartIndex respjson.Field
+ Text respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -345,12 +372,11 @@ type FileCitationDeltaAnnotationFileCitation struct {
FileID string `json:"file_id"`
// The specific quote in the file.
Quote string `json:"quote"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- Quote resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ Quote respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -371,15 +397,14 @@ type FilePathAnnotation struct {
Text string `json:"text,required"`
// Always `file_path`.
Type constant.FilePath `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- EndIndex resp.Field
- FilePath resp.Field
- StartIndex resp.Field
- Text resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ EndIndex respjson.Field
+ FilePath respjson.Field
+ StartIndex respjson.Field
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -393,11 +418,10 @@ func (r *FilePathAnnotation) UnmarshalJSON(data []byte) error {
type FilePathAnnotationFilePath struct {
// The ID of the file that was generated.
FileID string `json:"file_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -420,16 +444,15 @@ type FilePathDeltaAnnotation struct {
StartIndex int64 `json:"start_index"`
// The text in the message content that needs to be replaced.
Text string `json:"text"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- EndIndex resp.Field
- FilePath resp.Field
- StartIndex resp.Field
- Text resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ EndIndex respjson.Field
+ FilePath respjson.Field
+ StartIndex respjson.Field
+ Text respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -443,11 +466,10 @@ func (r *FilePathDeltaAnnotation) UnmarshalJSON(data []byte) error {
type FilePathDeltaAnnotationFilePath struct {
// The ID of the file that was generated.
FileID string `json:"file_id"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -468,12 +490,11 @@ type ImageFile struct {
//
// Any of "auto", "low", "high".
Detail ImageFileDetail `json:"detail"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- Detail resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ Detail respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -488,9 +509,9 @@ func (r *ImageFile) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ImageFileParam.IsOverridden()
+// ImageFileParam.Overrides()
func (r ImageFile) ToParam() ImageFileParam {
- return param.OverrideObj[ImageFileParam](r.RawJSON())
+ return param.Override[ImageFileParam](json.RawMessage(r.RawJSON()))
}
// Specifies the detail level of the image if specified by the user. `low` uses
@@ -517,13 +538,13 @@ type ImageFileParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageFileParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ImageFileParam) MarshalJSON() (data []byte, err error) {
type shadow ImageFileParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ImageFileParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// References an image [File](https://platform.openai.com/docs/api-reference/files)
// in the content of a message.
@@ -531,12 +552,11 @@ type ImageFileContentBlock struct {
ImageFile ImageFile `json:"image_file,required"`
// Always `image_file`.
Type constant.ImageFile `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ImageFile resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ImageFile respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -551,9 +571,9 @@ func (r *ImageFileContentBlock) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ImageFileContentBlockParam.IsOverridden()
+// ImageFileContentBlockParam.Overrides()
func (r ImageFileContentBlock) ToParam() ImageFileContentBlockParam {
- return param.OverrideObj[ImageFileContentBlockParam](r.RawJSON())
+ return param.Override[ImageFileContentBlockParam](json.RawMessage(r.RawJSON()))
}
// References an image [File](https://platform.openai.com/docs/api-reference/files)
@@ -569,13 +589,13 @@ type ImageFileContentBlockParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageFileContentBlockParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ImageFileContentBlockParam) MarshalJSON() (data []byte, err error) {
type shadow ImageFileContentBlockParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ImageFileContentBlockParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type ImageFileDelta struct {
// Specifies the detail level of the image if specified by the user. `low` uses
@@ -587,12 +607,11 @@ type ImageFileDelta struct {
// in the message content. Set `purpose="vision"` when uploading the File if you
// need to later display the file content.
FileID string `json:"file_id"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Detail resp.Field
- FileID resp.Field
- ExtraFields map[string]resp.Field
+ Detail respjson.Field
+ FileID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -621,13 +640,12 @@ type ImageFileDeltaBlock struct {
// Always `image_file`.
Type constant.ImageFile `json:"type,required"`
ImageFile ImageFileDelta `json:"image_file"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- ImageFile resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ ImageFile respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -647,12 +665,11 @@ type ImageURL struct {
//
// Any of "auto", "low", "high".
Detail ImageURLDetail `json:"detail"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- URL resp.Field
- Detail resp.Field
- ExtraFields map[string]resp.Field
+ URL respjson.Field
+ Detail respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -667,9 +684,9 @@ func (r *ImageURL) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ImageURLParam.IsOverridden()
+// ImageURLParam.Overrides()
func (r ImageURL) ToParam() ImageURLParam {
- return param.OverrideObj[ImageURLParam](r.RawJSON())
+ return param.Override[ImageURLParam](json.RawMessage(r.RawJSON()))
}
// Specifies the detail level of the image. `low` uses fewer tokens, you can opt in
@@ -695,25 +712,24 @@ type ImageURLParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageURLParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ImageURLParam) MarshalJSON() (data []byte, err error) {
type shadow ImageURLParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ImageURLParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// References an image URL in the content of a message.
type ImageURLContentBlock struct {
ImageURL ImageURL `json:"image_url,required"`
// The type of the content part.
Type constant.ImageURL `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ImageURL resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ImageURL respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -728,9 +744,9 @@ func (r *ImageURLContentBlock) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ImageURLContentBlockParam.IsOverridden()
+// ImageURLContentBlockParam.Overrides()
func (r ImageURLContentBlock) ToParam() ImageURLContentBlockParam {
- return param.OverrideObj[ImageURLContentBlockParam](r.RawJSON())
+ return param.Override[ImageURLContentBlockParam](json.RawMessage(r.RawJSON()))
}
// References an image URL in the content of a message.
@@ -745,13 +761,13 @@ type ImageURLContentBlockParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageURLContentBlockParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ImageURLContentBlockParam) MarshalJSON() (data []byte, err error) {
type shadow ImageURLContentBlockParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ImageURLContentBlockParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type ImageURLDelta struct {
// Specifies the detail level of the image. `low` uses fewer tokens, you can opt in
@@ -762,12 +778,11 @@ type ImageURLDelta struct {
// The URL of the image, must be a supported image types: jpeg, jpg, png, gif,
// webp.
URL string `json:"url"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Detail resp.Field
- URL resp.Field
- ExtraFields map[string]resp.Field
+ Detail respjson.Field
+ URL respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -795,13 +810,12 @@ type ImageURLDeltaBlock struct {
// Always `image_url`.
Type constant.ImageURL `json:"type,required"`
ImageURL ImageURLDelta `json:"image_url"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- ImageURL resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ ImageURL respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -858,24 +872,23 @@ type Message struct {
// The [thread](https://platform.openai.com/docs/api-reference/threads) ID that
// this message belongs to.
ThreadID string `json:"thread_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- AssistantID resp.Field
- Attachments resp.Field
- CompletedAt resp.Field
- Content resp.Field
- CreatedAt resp.Field
- IncompleteAt resp.Field
- IncompleteDetails resp.Field
- Metadata resp.Field
- Object resp.Field
- Role resp.Field
- RunID resp.Field
- Status resp.Field
- ThreadID resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ AssistantID respjson.Field
+ Attachments respjson.Field
+ CompletedAt respjson.Field
+ Content respjson.Field
+ CreatedAt respjson.Field
+ IncompleteAt respjson.Field
+ IncompleteDetails respjson.Field
+ Metadata respjson.Field
+ Object respjson.Field
+ Role respjson.Field
+ RunID respjson.Field
+ Status respjson.Field
+ ThreadID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -891,12 +904,11 @@ type MessageAttachment struct {
FileID string `json:"file_id"`
// The tools to add this file to.
Tools []MessageAttachmentToolUnion `json:"tools"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- Tools resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ Tools respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -908,13 +920,13 @@ func (r *MessageAttachment) UnmarshalJSON(data []byte) error {
}
// MessageAttachmentToolUnion contains all possible properties and values from
-// [CodeInterpreterTool], [MessageAttachmentToolAssistantToolsFileSearchTypeOnly].
+// [CodeInterpreterTool], [MessageAttachmentToolFileSearchTool].
//
// Use the methods beginning with 'As' to cast the union to one of its variants.
type MessageAttachmentToolUnion struct {
Type string `json:"type"`
JSON struct {
- Type resp.Field
+ Type respjson.Field
raw string
} `json:"-"`
}
@@ -924,7 +936,7 @@ func (u MessageAttachmentToolUnion) AsCodeInterpreterTool() (v CodeInterpreterTo
return
}
-func (u MessageAttachmentToolUnion) AsFileSearchTool() (v MessageAttachmentToolAssistantToolsFileSearchTypeOnly) {
+func (u MessageAttachmentToolUnion) AsFileSearchTool() (v MessageAttachmentToolFileSearchTool) {
apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
return
}
@@ -936,21 +948,20 @@ func (r *MessageAttachmentToolUnion) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
-type MessageAttachmentToolAssistantToolsFileSearchTypeOnly struct {
+type MessageAttachmentToolFileSearchTool struct {
// The type of tool being defined: `file_search`
Type constant.FileSearch `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
// Returns the unmodified JSON received from the API
-func (r MessageAttachmentToolAssistantToolsFileSearchTypeOnly) RawJSON() string { return r.JSON.raw }
-func (r *MessageAttachmentToolAssistantToolsFileSearchTypeOnly) UnmarshalJSON(data []byte) error {
+func (r MessageAttachmentToolFileSearchTool) RawJSON() string { return r.JSON.raw }
+func (r *MessageAttachmentToolFileSearchTool) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
@@ -961,11 +972,10 @@ type MessageIncompleteDetails struct {
// Any of "content_filter", "max_tokens", "run_cancelled", "run_expired",
// "run_failed".
Reason string `json:"reason,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Reason resp.Field
- ExtraFields map[string]resp.Field
+ Reason respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1013,26 +1023,37 @@ type MessageContentUnion struct {
// This field is from variant [RefusalContentBlock].
Refusal string `json:"refusal"`
JSON struct {
- ImageFile resp.Field
- Type resp.Field
- ImageURL resp.Field
- Text resp.Field
- Refusal resp.Field
+ ImageFile respjson.Field
+ Type respjson.Field
+ ImageURL respjson.Field
+ Text respjson.Field
+ Refusal respjson.Field
raw string
} `json:"-"`
}
+// anyMessageContent is implemented by each variant of [MessageContentUnion] to add
+// type safety for the return type of [MessageContentUnion.AsAny]
+type anyMessageContent interface {
+ implMessageContentUnion()
+}
+
+func (ImageFileContentBlock) implMessageContentUnion() {}
+func (ImageURLContentBlock) implMessageContentUnion() {}
+func (TextContentBlock) implMessageContentUnion() {}
+func (RefusalContentBlock) implMessageContentUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := MessageContentUnion.AsAny().(type) {
-// case ImageFileContentBlock:
-// case ImageURLContentBlock:
-// case TextContentBlock:
-// case RefusalContentBlock:
+// case openai.ImageFileContentBlock:
+// case openai.ImageURLContentBlock:
+// case openai.TextContentBlock:
+// case openai.RefusalContentBlock:
// default:
// fmt.Errorf("no variant present")
// }
-func (u MessageContentUnion) AsAny() any {
+func (u MessageContentUnion) AsAny() anyMessageContent {
switch u.Type {
case "image_file":
return u.AsImageFile()
@@ -1093,27 +1114,39 @@ type MessageContentDeltaUnion struct {
// This field is from variant [ImageURLDeltaBlock].
ImageURL ImageURLDelta `json:"image_url"`
JSON struct {
- Index resp.Field
- Type resp.Field
- ImageFile resp.Field
- Text resp.Field
- Refusal resp.Field
- ImageURL resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ ImageFile respjson.Field
+ Text respjson.Field
+ Refusal respjson.Field
+ ImageURL respjson.Field
raw string
} `json:"-"`
}
+// anyMessageContentDelta is implemented by each variant of
+// [MessageContentDeltaUnion] to add type safety for the return type of
+// [MessageContentDeltaUnion.AsAny]
+type anyMessageContentDelta interface {
+ implMessageContentDeltaUnion()
+}
+
+func (ImageFileDeltaBlock) implMessageContentDeltaUnion() {}
+func (TextDeltaBlock) implMessageContentDeltaUnion() {}
+func (RefusalDeltaBlock) implMessageContentDeltaUnion() {}
+func (ImageURLDeltaBlock) implMessageContentDeltaUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := MessageContentDeltaUnion.AsAny().(type) {
-// case ImageFileDeltaBlock:
-// case TextDeltaBlock:
-// case RefusalDeltaBlock:
-// case ImageURLDeltaBlock:
+// case openai.ImageFileDeltaBlock:
+// case openai.TextDeltaBlock:
+// case openai.RefusalDeltaBlock:
+// case openai.ImageURLDeltaBlock:
// default:
// fmt.Errorf("no variant present")
// }
-func (u MessageContentDeltaUnion) AsAny() any {
+func (u MessageContentDeltaUnion) AsAny() anyMessageContentDelta {
switch u.Type {
case "image_file":
return u.AsImageFile()
@@ -1182,11 +1215,11 @@ type MessageContentPartParamUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u MessageContentPartParamUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u MessageContentPartParamUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[MessageContentPartParamUnion](u.OfImageFile, u.OfImageURL, u.OfText)
+ return param.MarshalUnion(u, u.OfImageFile, u.OfImageURL, u.OfText)
+}
+func (u *MessageContentPartParamUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *MessageContentPartParamUnion) asAny() any {
@@ -1239,21 +1272,9 @@ func (u MessageContentPartParamUnion) GetType() *string {
func init() {
apijson.RegisterUnion[MessageContentPartParamUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ImageFileContentBlockParam{}),
- DiscriminatorValue: "image_file",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ImageURLContentBlockParam{}),
- DiscriminatorValue: "image_url",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(TextContentBlockParam{}),
- DiscriminatorValue: "text",
- },
+ apijson.Discriminator[ImageFileContentBlockParam]("image_file"),
+ apijson.Discriminator[ImageURLContentBlockParam]("image_url"),
+ apijson.Discriminator[TextContentBlockParam]("text"),
)
}
@@ -1261,13 +1282,12 @@ type MessageDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object constant.ThreadMessageDeleted `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1286,12 +1306,11 @@ type MessageDelta struct {
//
// Any of "user", "assistant".
Role MessageDeltaRole `json:"role"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Content resp.Field
- Role resp.Field
- ExtraFields map[string]resp.Field
+ Content respjson.Field
+ Role respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1319,13 +1338,12 @@ type MessageDeltaEvent struct {
Delta MessageDelta `json:"delta,required"`
// The object type, which is always `thread.message.delta`.
Object constant.ThreadMessageDelta `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Delta resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Delta respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1341,12 +1359,11 @@ type RefusalContentBlock struct {
Refusal string `json:"refusal,required"`
// Always `refusal`.
Type constant.Refusal `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Refusal resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Refusal respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1364,13 +1381,12 @@ type RefusalDeltaBlock struct {
// Always `refusal`.
Type constant.Refusal `json:"type,required"`
Refusal string `json:"refusal"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- Refusal resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ Refusal respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1385,12 +1401,11 @@ type Text struct {
Annotations []AnnotationUnion `json:"annotations,required"`
// The data that makes up the text.
Value string `json:"value,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Annotations resp.Field
- Value resp.Field
- ExtraFields map[string]resp.Field
+ Annotations respjson.Field
+ Value respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1406,12 +1421,11 @@ type TextContentBlock struct {
Text Text `json:"text,required"`
// Always `text`.
Type constant.Text `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1435,24 +1449,23 @@ type TextContentBlockParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f TextContentBlockParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r TextContentBlockParam) MarshalJSON() (data []byte, err error) {
type shadow TextContentBlockParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *TextContentBlockParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type TextDelta struct {
Annotations []AnnotationDeltaUnion `json:"annotations"`
// The data that makes up the text.
Value string `json:"value"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Annotations resp.Field
- Value resp.Field
- ExtraFields map[string]resp.Field
+ Annotations respjson.Field
+ Value respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -8,7 +8,6 @@ import (
"fmt"
"net/http"
"net/url"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/apiquery"
@@ -16,11 +15,10 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/packages/ssestream"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// BetaThreadRunService contains methods and other services that help with
@@ -29,9 +27,12 @@ import (
// Note, unlike clients, this service does not read variables from the environment
// automatically. You should not instantiate this service directly, and instead use
// the [NewBetaThreadRunService] method instead.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
type BetaThreadRunService struct {
Options []option.RequestOption
- Steps BetaThreadRunStepService
+ // Deprecated: The Assistants API is deprecated in favor of the Responses API
+ Steps BetaThreadRunStepService
}
// NewBetaThreadRunService generates a new service that applies the given options
@@ -45,6 +46,8 @@ func NewBetaThreadRunService(opts ...option.RequestOption) (r BetaThreadRunServi
}
// Create a run.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) New(ctx context.Context, threadID string, params BetaThreadRunNewParams, opts ...option.RequestOption) (res *Run, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -57,17 +60,9 @@ func (r *BetaThreadRunService) New(ctx context.Context, threadID string, params
return
}
-// Create a run and poll until task is completed.
-// Pass 0 to pollIntervalMs to use the default polling interval.
-func (r *BetaThreadRunService) NewAndPoll(ctx context.Context, threadID string, params BetaThreadRunNewParams, pollIntervalMs int, opts ...option.RequestOption) (res *Run, err error) {
- run, err := r.New(ctx, threadID, params, opts...)
- if err != nil {
- return nil, err
- }
- return r.PollStatus(ctx, threadID, run.ID, pollIntervalMs, opts...)
-}
-
// Create a run.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) NewStreaming(ctx context.Context, threadID string, params BetaThreadRunNewParams, opts ...option.RequestOption) (stream *ssestream.Stream[AssistantStreamEventUnion]) {
var (
raw *http.Response
@@ -85,6 +80,8 @@ func (r *BetaThreadRunService) NewStreaming(ctx context.Context, threadID string
}
// Retrieves a run.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) Get(ctx context.Context, threadID string, runID string, opts ...option.RequestOption) (res *Run, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -102,6 +99,8 @@ func (r *BetaThreadRunService) Get(ctx context.Context, threadID string, runID s
}
// Modifies a run.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) Update(ctx context.Context, threadID string, runID string, body BetaThreadRunUpdateParams, opts ...option.RequestOption) (res *Run, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -119,6 +118,8 @@ func (r *BetaThreadRunService) Update(ctx context.Context, threadID string, runI
}
// Returns a list of runs belonging to a thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) List(ctx context.Context, threadID string, query BetaThreadRunListParams, opts ...option.RequestOption) (res *pagination.CursorPage[Run], err error) {
var raw *http.Response
opts = append(r.Options[:], opts...)
@@ -141,11 +142,15 @@ func (r *BetaThreadRunService) List(ctx context.Context, threadID string, query
}
// Returns a list of runs belonging to a thread.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) ListAutoPaging(ctx context.Context, threadID string, query BetaThreadRunListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[Run] {
return pagination.NewCursorPageAutoPager(r.List(ctx, threadID, query, opts...))
}
// Cancels a run that is `in_progress`.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) Cancel(ctx context.Context, threadID string, runID string, opts ...option.RequestOption) (res *Run, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -166,6 +171,8 @@ func (r *BetaThreadRunService) Cancel(ctx context.Context, threadID string, runI
// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the
// tool calls once they're all completed. All outputs must be submitted in a single
// request.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) SubmitToolOutputs(ctx context.Context, threadID string, runID string, body BetaThreadRunSubmitToolOutputsParams, opts ...option.RequestOption) (res *Run, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -182,22 +189,12 @@ func (r *BetaThreadRunService) SubmitToolOutputs(ctx context.Context, threadID s
return
}
-// A helper to submit a tool output to a run and poll for a terminal run state.
-// Pass 0 to pollIntervalMs to use the default polling interval.
-// More information on Run lifecycles can be found here:
-// https://platform.openai.com/docs/assistants/how-it-works/runs-and-run-steps
-func (r *BetaThreadRunService) SubmitToolOutputsAndPoll(ctx context.Context, threadID string, runID string, body BetaThreadRunSubmitToolOutputsParams, pollIntervalMs int, opts ...option.RequestOption) (*Run, error) {
- run, err := r.SubmitToolOutputs(ctx, threadID, runID, body, opts...)
- if err != nil {
- return nil, err
- }
- return r.PollStatus(ctx, threadID, run.ID, pollIntervalMs, opts...)
-}
-
// When a run has the `status: "requires_action"` and `required_action.type` is
// `submit_tool_outputs`, this endpoint can be used to submit the outputs from the
// tool calls once they're all completed. All outputs must be submitted in a single
// request.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunService) SubmitToolOutputsStreaming(ctx context.Context, threadID string, runID string, body BetaThreadRunSubmitToolOutputsParams, opts ...option.RequestOption) (stream *ssestream.Stream[AssistantStreamEventUnion]) {
var (
raw *http.Response
@@ -230,13 +227,12 @@ type RequiredActionFunctionToolCall struct {
// The type of tool call the output is required for. For now, this is always
// `function`.
Type constant.Function `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Function resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Function respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -253,12 +249,11 @@ type RequiredActionFunctionToolCallFunction struct {
Arguments string `json:"arguments,required"`
// The name of the function.
Name string `json:"name,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Arguments resp.Field
- Name resp.Field
- ExtraFields map[string]resp.Field
+ Arguments respjson.Field
+ Name respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -378,37 +373,36 @@ type Run struct {
Temperature float64 `json:"temperature,nullable"`
// The nucleus sampling value used for this run. If not set, defaults to 1.
TopP float64 `json:"top_p,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- AssistantID resp.Field
- CancelledAt resp.Field
- CompletedAt resp.Field
- CreatedAt resp.Field
- ExpiresAt resp.Field
- FailedAt resp.Field
- IncompleteDetails resp.Field
- Instructions resp.Field
- LastError resp.Field
- MaxCompletionTokens resp.Field
- MaxPromptTokens resp.Field
- Metadata resp.Field
- Model resp.Field
- Object resp.Field
- ParallelToolCalls resp.Field
- RequiredAction resp.Field
- ResponseFormat resp.Field
- StartedAt resp.Field
- Status resp.Field
- ThreadID resp.Field
- ToolChoice resp.Field
- Tools resp.Field
- TruncationStrategy resp.Field
- Usage resp.Field
- Temperature resp.Field
- TopP resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ AssistantID respjson.Field
+ CancelledAt respjson.Field
+ CompletedAt respjson.Field
+ CreatedAt respjson.Field
+ ExpiresAt respjson.Field
+ FailedAt respjson.Field
+ IncompleteDetails respjson.Field
+ Instructions respjson.Field
+ LastError respjson.Field
+ MaxCompletionTokens respjson.Field
+ MaxPromptTokens respjson.Field
+ Metadata respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ ParallelToolCalls respjson.Field
+ RequiredAction respjson.Field
+ ResponseFormat respjson.Field
+ StartedAt respjson.Field
+ Status respjson.Field
+ ThreadID respjson.Field
+ ToolChoice respjson.Field
+ Tools respjson.Field
+ TruncationStrategy respjson.Field
+ Usage respjson.Field
+ Temperature respjson.Field
+ TopP respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -427,11 +421,10 @@ type RunIncompleteDetails struct {
//
// Any of "max_completion_tokens", "max_prompt_tokens".
Reason string `json:"reason"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Reason resp.Field
- ExtraFields map[string]resp.Field
+ Reason respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -450,12 +443,11 @@ type RunLastError struct {
Code string `json:"code,required"`
// A human-readable description of the error.
Message string `json:"message,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Message resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Message respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -473,12 +465,11 @@ type RunRequiredAction struct {
SubmitToolOutputs RunRequiredActionSubmitToolOutputs `json:"submit_tool_outputs,required"`
// For now, this is always `submit_tool_outputs`.
Type constant.SubmitToolOutputs `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- SubmitToolOutputs resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ SubmitToolOutputs respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -493,11 +484,10 @@ func (r *RunRequiredAction) UnmarshalJSON(data []byte) error {
type RunRequiredActionSubmitToolOutputs struct {
// A list of the relevant tool calls.
ToolCalls []RequiredActionFunctionToolCall `json:"tool_calls,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ToolCalls resp.Field
- ExtraFields map[string]resp.Field
+ ToolCalls respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -521,12 +511,11 @@ type RunTruncationStrategy struct {
// The number of most recent messages from the thread when constructing the context
// for the run.
LastMessages int64 `json:"last_messages,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- LastMessages resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ LastMessages respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -546,13 +535,12 @@ type RunUsage struct {
PromptTokens int64 `json:"prompt_tokens,required"`
// Total number of tokens used (prompt + completion).
TotalTokens int64 `json:"total_tokens,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- CompletionTokens resp.Field
- PromptTokens resp.Field
- TotalTokens resp.Field
- ExtraFields map[string]resp.Field
+ CompletionTokens respjson.Field
+ PromptTokens respjson.Field
+ TotalTokens respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -627,7 +615,7 @@ type BetaThreadRunNewParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// The ID of the [Model](https://platform.openai.com/docs/api-reference/models) to
// be used to execute this run. If a value is provided here, it will override the
// model associated with the assistant. If not, the model associated with the
@@ -688,17 +676,16 @@ type BetaThreadRunNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r BetaThreadRunNewParams) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// URLQuery serializes [BetaThreadRunNewParams]'s query parameters as `url.Values`.
-func (r BetaThreadRunNewParams) URLQuery() (v url.Values) {
+func (r BetaThreadRunNewParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -726,23 +713,21 @@ type BetaThreadRunNewParamsAdditionalMessage struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunNewParamsAdditionalMessage) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadRunNewParamsAdditionalMessage) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunNewParamsAdditionalMessage
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunNewParamsAdditionalMessage) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[BetaThreadRunNewParamsAdditionalMessage](
- "Role", false, "user", "assistant",
+ "role", "user", "assistant",
)
}
@@ -755,13 +740,11 @@ type BetaThreadRunNewParamsAdditionalMessageContentUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadRunNewParamsAdditionalMessageContentUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadRunNewParamsAdditionalMessageContentUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadRunNewParamsAdditionalMessageContentUnion](u.OfString, u.OfArrayOfContentParts)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
+}
+func (u *BetaThreadRunNewParamsAdditionalMessageContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadRunNewParamsAdditionalMessageContentUnion) asAny() any {
@@ -781,15 +764,13 @@ type BetaThreadRunNewParamsAdditionalMessageAttachment struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunNewParamsAdditionalMessageAttachment) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadRunNewParamsAdditionalMessageAttachment) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunNewParamsAdditionalMessageAttachment
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunNewParamsAdditionalMessageAttachment) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -800,13 +781,11 @@ type BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion](u.OfCodeInterpreter, u.OfFileSearch)
+ return param.MarshalUnion(u, u.OfCodeInterpreter, u.OfFileSearch)
+}
+func (u *BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) asAny() any {
@@ -831,37 +810,32 @@ func (u BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion) GetType() *s
func init() {
apijson.RegisterUnion[BetaThreadRunNewParamsAdditionalMessageAttachmentToolUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(CodeInterpreterToolParam{}),
- DiscriminatorValue: "code_interpreter",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch{}),
- DiscriminatorValue: "file_search",
- },
+ apijson.Discriminator[CodeInterpreterToolParam]("code_interpreter"),
+ apijson.Discriminator[BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch]("file_search"),
)
}
-// The property Type is required.
+func NewBetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch() BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch {
+ return BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch{
+ Type: "file_search",
+ }
+}
+
+// This struct has a constant value, construct it with
+// [NewBetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch].
type BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch struct {
// The type of tool being defined: `file_search`
- //
- // This field can be elided, and will marshal its zero value as "file_search".
Type constant.FileSearch `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunNewParamsAdditionalMessageAttachmentToolFileSearch) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Controls for how a thread will be truncated prior to the run. Use this to
// control the intial context window of the run.
@@ -881,19 +855,17 @@ type BetaThreadRunNewParamsTruncationStrategy struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunNewParamsTruncationStrategy) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadRunNewParamsTruncationStrategy) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunNewParamsTruncationStrategy
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunNewParamsTruncationStrategy) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[BetaThreadRunNewParamsTruncationStrategy](
- "Type", false, "auto", "last_messages",
+ "type", "auto", "last_messages",
)
}
@@ -904,18 +876,17 @@ type BetaThreadRunUpdateParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunUpdateParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r BetaThreadRunUpdateParams) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunUpdateParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunUpdateParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadRunListParams struct {
// A cursor for use in pagination. `after` is an object ID that defines your place
@@ -939,13 +910,9 @@ type BetaThreadRunListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [BetaThreadRunListParams]'s query parameters as
// `url.Values`.
-func (r BetaThreadRunListParams) URLQuery() (v url.Values) {
+func (r BetaThreadRunListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -967,16 +934,13 @@ type BetaThreadRunSubmitToolOutputsParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunSubmitToolOutputsParams) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
-
func (r BetaThreadRunSubmitToolOutputsParams) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunSubmitToolOutputsParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunSubmitToolOutputsParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type BetaThreadRunSubmitToolOutputsParamsToolOutput struct {
// The output of the tool call to be submitted to continue the run.
@@ -987,12 +951,10 @@ type BetaThreadRunSubmitToolOutputsParamsToolOutput struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunSubmitToolOutputsParamsToolOutput) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r BetaThreadRunSubmitToolOutputsParamsToolOutput) MarshalJSON() (data []byte, err error) {
type shadow BetaThreadRunSubmitToolOutputsParamsToolOutput
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *BetaThreadRunSubmitToolOutputsParamsToolOutput) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
@@ -16,7 +16,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
)
@@ -27,6 +27,8 @@ import (
// Note, unlike clients, this service does not read variables from the environment
// automatically. You should not instantiate this service directly, and instead use
// the [NewBetaThreadRunStepService] method instead.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
type BetaThreadRunStepService struct {
Options []option.RequestOption
}
@@ -41,6 +43,8 @@ func NewBetaThreadRunStepService(opts ...option.RequestOption) (r BetaThreadRunS
}
// Retrieves a run step.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunStepService) Get(ctx context.Context, threadID string, runID string, stepID string, query BetaThreadRunStepGetParams, opts ...option.RequestOption) (res *RunStep, err error) {
opts = append(r.Options[:], opts...)
opts = append([]option.RequestOption{option.WithHeader("OpenAI-Beta", "assistants=v2")}, opts...)
@@ -62,6 +66,8 @@ func (r *BetaThreadRunStepService) Get(ctx context.Context, threadID string, run
}
// Returns a list of run steps belonging to a run.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunStepService) List(ctx context.Context, threadID string, runID string, query BetaThreadRunStepListParams, opts ...option.RequestOption) (res *pagination.CursorPage[RunStep], err error) {
var raw *http.Response
opts = append(r.Options[:], opts...)
@@ -88,6 +94,8 @@ func (r *BetaThreadRunStepService) List(ctx context.Context, threadID string, ru
}
// Returns a list of run steps belonging to a run.
+//
+// Deprecated: The Assistants API is deprecated in favor of the Responses API
func (r *BetaThreadRunStepService) ListAutoPaging(ctx context.Context, threadID string, runID string, query BetaThreadRunStepListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[RunStep] {
return pagination.NewCursorPageAutoPager(r.List(ctx, threadID, runID, query, opts...))
}
@@ -100,13 +108,12 @@ type CodeInterpreterLogs struct {
Type constant.Logs `json:"type,required"`
// The text output from the Code Interpreter tool call.
Logs string `json:"logs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- Logs resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ Logs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -123,13 +130,12 @@ type CodeInterpreterOutputImage struct {
// Always `image`.
Type constant.Image `json:"type,required"`
Image CodeInterpreterOutputImageImage `json:"image"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- Image resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ Image respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -144,11 +150,10 @@ type CodeInterpreterOutputImageImage struct {
// The [file](https://platform.openai.com/docs/api-reference/files) ID of the
// image.
FileID string `json:"file_id"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -168,13 +173,12 @@ type CodeInterpreterToolCall struct {
// The type of tool call. This is always going to be `code_interpreter` for this
// type of tool call.
Type constant.CodeInterpreter `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CodeInterpreter resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CodeInterpreter respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -193,12 +197,11 @@ type CodeInterpreterToolCallCodeInterpreter struct {
// or more items, including text (`logs`) or images (`image`). Each of these are
// represented by a different object type.
Outputs []CodeInterpreterToolCallCodeInterpreterOutputUnion `json:"outputs,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Input resp.Field
- Outputs resp.Field
- ExtraFields map[string]resp.Field
+ Input respjson.Field
+ Outputs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -225,22 +228,34 @@ type CodeInterpreterToolCallCodeInterpreterOutputUnion struct {
// This field is from variant [CodeInterpreterToolCallCodeInterpreterOutputImage].
Image CodeInterpreterToolCallCodeInterpreterOutputImageImage `json:"image"`
JSON struct {
- Logs resp.Field
- Type resp.Field
- Image resp.Field
+ Logs respjson.Field
+ Type respjson.Field
+ Image respjson.Field
raw string
} `json:"-"`
}
+// anyCodeInterpreterToolCallCodeInterpreterOutput is implemented by each variant
+// of [CodeInterpreterToolCallCodeInterpreterOutputUnion] to add type safety for
+// the return type of [CodeInterpreterToolCallCodeInterpreterOutputUnion.AsAny]
+type anyCodeInterpreterToolCallCodeInterpreterOutput interface {
+ implCodeInterpreterToolCallCodeInterpreterOutputUnion()
+}
+
+func (CodeInterpreterToolCallCodeInterpreterOutputLogs) implCodeInterpreterToolCallCodeInterpreterOutputUnion() {
+}
+func (CodeInterpreterToolCallCodeInterpreterOutputImage) implCodeInterpreterToolCallCodeInterpreterOutputUnion() {
+}
+
// Use the following switch statement to find the correct variant
//
// switch variant := CodeInterpreterToolCallCodeInterpreterOutputUnion.AsAny().(type) {
-// case CodeInterpreterToolCallCodeInterpreterOutputLogs:
-// case CodeInterpreterToolCallCodeInterpreterOutputImage:
+// case openai.CodeInterpreterToolCallCodeInterpreterOutputLogs:
+// case openai.CodeInterpreterToolCallCodeInterpreterOutputImage:
// default:
// fmt.Errorf("no variant present")
// }
-func (u CodeInterpreterToolCallCodeInterpreterOutputUnion) AsAny() any {
+func (u CodeInterpreterToolCallCodeInterpreterOutputUnion) AsAny() anyCodeInterpreterToolCallCodeInterpreterOutput {
switch u.Type {
case "logs":
return u.AsLogs()
@@ -273,12 +288,11 @@ type CodeInterpreterToolCallCodeInterpreterOutputLogs struct {
Logs string `json:"logs,required"`
// Always `logs`.
Type constant.Logs `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Logs resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Logs respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -293,12 +307,11 @@ type CodeInterpreterToolCallCodeInterpreterOutputImage struct {
Image CodeInterpreterToolCallCodeInterpreterOutputImageImage `json:"image,required"`
// Always `image`.
Type constant.Image `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Image resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Image respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -313,11 +326,10 @@ type CodeInterpreterToolCallCodeInterpreterOutputImageImage struct {
// The [file](https://platform.openai.com/docs/api-reference/files) ID of the
// image.
FileID string `json:"file_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -339,14 +351,13 @@ type CodeInterpreterToolCallDelta struct {
ID string `json:"id"`
// The Code Interpreter tool call definition.
CodeInterpreter CodeInterpreterToolCallDeltaCodeInterpreter `json:"code_interpreter"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- ID resp.Field
- CodeInterpreter resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ ID respjson.Field
+ CodeInterpreter respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -365,12 +376,11 @@ type CodeInterpreterToolCallDeltaCodeInterpreter struct {
// or more items, including text (`logs`) or images (`image`). Each of these are
// represented by a different object type.
Outputs []CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion `json:"outputs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Input resp.Field
- Outputs resp.Field
- ExtraFields map[string]resp.Field
+ Input respjson.Field
+ Outputs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -397,23 +407,34 @@ type CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion struct {
// This field is from variant [CodeInterpreterOutputImage].
Image CodeInterpreterOutputImageImage `json:"image"`
JSON struct {
- Index resp.Field
- Type resp.Field
- Logs resp.Field
- Image resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ Logs respjson.Field
+ Image respjson.Field
raw string
} `json:"-"`
}
+// anyCodeInterpreterToolCallDeltaCodeInterpreterOutput is implemented by each
+// variant of [CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion] to add type
+// safety for the return type of
+// [CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion.AsAny]
+type anyCodeInterpreterToolCallDeltaCodeInterpreterOutput interface {
+ implCodeInterpreterToolCallDeltaCodeInterpreterOutputUnion()
+}
+
+func (CodeInterpreterLogs) implCodeInterpreterToolCallDeltaCodeInterpreterOutputUnion() {}
+func (CodeInterpreterOutputImage) implCodeInterpreterToolCallDeltaCodeInterpreterOutputUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion.AsAny().(type) {
-// case CodeInterpreterLogs:
-// case CodeInterpreterOutputImage:
+// case openai.CodeInterpreterLogs:
+// case openai.CodeInterpreterOutputImage:
// default:
// fmt.Errorf("no variant present")
// }
-func (u CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) AsAny() any {
+func (u CodeInterpreterToolCallDeltaCodeInterpreterOutputUnion) AsAny() anyCodeInterpreterToolCallDeltaCodeInterpreterOutput {
switch u.Type {
case "logs":
return u.AsLogs()
@@ -448,13 +469,12 @@ type FileSearchToolCall struct {
// The type of tool call. This is always going to be `file_search` for this type of
// tool call.
Type constant.FileSearch `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- FileSearch resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ FileSearch respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -471,12 +491,11 @@ type FileSearchToolCallFileSearch struct {
RankingOptions FileSearchToolCallFileSearchRankingOptions `json:"ranking_options"`
// The results of the file search.
Results []FileSearchToolCallFileSearchResult `json:"results"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- RankingOptions resp.Field
- Results resp.Field
- ExtraFields map[string]resp.Field
+ RankingOptions respjson.Field
+ Results respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -497,12 +516,11 @@ type FileSearchToolCallFileSearchRankingOptions struct {
// The score threshold for the file search. All values must be a floating point
// number between 0 and 1.
ScoreThreshold float64 `json:"score_threshold,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Ranker resp.Field
- ScoreThreshold resp.Field
- ExtraFields map[string]resp.Field
+ Ranker respjson.Field
+ ScoreThreshold respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -525,14 +543,13 @@ type FileSearchToolCallFileSearchResult struct {
// The content of the result that was found. The content is only included if
// requested via the include query parameter.
Content []FileSearchToolCallFileSearchResultContent `json:"content"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileID resp.Field
- FileName resp.Field
- Score resp.Field
- Content resp.Field
- ExtraFields map[string]resp.Field
+ FileID respjson.Field
+ FileName respjson.Field
+ Score respjson.Field
+ Content respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -550,12 +567,11 @@ type FileSearchToolCallFileSearchResultContent struct {
//
// Any of "text".
Type string `json:"type"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -568,7 +584,7 @@ func (r *FileSearchToolCallFileSearchResultContent) UnmarshalJSON(data []byte) e
type FileSearchToolCallDelta struct {
// For now, this is always going to be an empty object.
- FileSearch interface{} `json:"file_search,required"`
+ FileSearch any `json:"file_search,required"`
// The index of the tool call in the tool calls array.
Index int64 `json:"index,required"`
// The type of tool call. This is always going to be `file_search` for this type of
@@ -576,14 +592,13 @@ type FileSearchToolCallDelta struct {
Type constant.FileSearch `json:"type,required"`
// The ID of the tool call object.
ID string `json:"id"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FileSearch resp.Field
- Index resp.Field
- Type resp.Field
- ID resp.Field
- ExtraFields map[string]resp.Field
+ FileSearch respjson.Field
+ Index respjson.Field
+ Type respjson.Field
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -602,13 +617,12 @@ type FunctionToolCall struct {
// The type of tool call. This is always going to be `function` for this type of
// tool call.
Type constant.Function `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Function resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Function respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -629,13 +643,12 @@ type FunctionToolCallFunction struct {
// [submitted](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs)
// yet.
Output string `json:"output,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Arguments resp.Field
- Name resp.Field
- Output resp.Field
- ExtraFields map[string]resp.Field
+ Arguments respjson.Field
+ Name respjson.Field
+ Output respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -656,14 +669,13 @@ type FunctionToolCallDelta struct {
ID string `json:"id"`
// The definition of the function that was called.
Function FunctionToolCallDeltaFunction `json:"function"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- Type resp.Field
- ID resp.Field
- Function resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ ID respjson.Field
+ Function respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -684,13 +696,12 @@ type FunctionToolCallDeltaFunction struct {
// [submitted](https://platform.openai.com/docs/api-reference/runs/submitToolOutputs)
// yet.
Output string `json:"output,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Arguments resp.Field
- Name resp.Field
- Output resp.Field
- ExtraFields map[string]resp.Field
+ Arguments respjson.Field
+ Name respjson.Field
+ Output respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -706,12 +717,11 @@ type MessageCreationStepDetails struct {
MessageCreation MessageCreationStepDetailsMessageCreation `json:"message_creation,required"`
// Always `message_creation`.
Type constant.MessageCreation `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- MessageCreation resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ MessageCreation respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -725,11 +735,10 @@ func (r *MessageCreationStepDetails) UnmarshalJSON(data []byte) error {
type MessageCreationStepDetailsMessageCreation struct {
// The ID of the message that was created by this run step.
MessageID string `json:"message_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- MessageID resp.Field
- ExtraFields map[string]resp.Field
+ MessageID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -791,26 +800,25 @@ type RunStep struct {
// Usage statistics related to the run step. This value will be `null` while the
// run step's status is `in_progress`.
Usage RunStepUsage `json:"usage,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- AssistantID resp.Field
- CancelledAt resp.Field
- CompletedAt resp.Field
- CreatedAt resp.Field
- ExpiredAt resp.Field
- FailedAt resp.Field
- LastError resp.Field
- Metadata resp.Field
- Object resp.Field
- RunID resp.Field
- Status resp.Field
- StepDetails resp.Field
- ThreadID resp.Field
- Type resp.Field
- Usage resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ AssistantID respjson.Field
+ CancelledAt respjson.Field
+ CompletedAt respjson.Field
+ CreatedAt respjson.Field
+ ExpiredAt respjson.Field
+ FailedAt respjson.Field
+ LastError respjson.Field
+ Metadata respjson.Field
+ Object respjson.Field
+ RunID respjson.Field
+ Status respjson.Field
+ StepDetails respjson.Field
+ ThreadID respjson.Field
+ Type respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -830,12 +838,11 @@ type RunStepLastError struct {
Code string `json:"code,required"`
// A human-readable description of the error.
Message string `json:"message,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Message resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Message respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -872,22 +879,32 @@ type RunStepStepDetailsUnion struct {
// This field is from variant [ToolCallsStepDetails].
ToolCalls []ToolCallUnion `json:"tool_calls"`
JSON struct {
- MessageCreation resp.Field
- Type resp.Field
- ToolCalls resp.Field
+ MessageCreation respjson.Field
+ Type respjson.Field
+ ToolCalls respjson.Field
raw string
} `json:"-"`
}
+// anyRunStepStepDetails is implemented by each variant of
+// [RunStepStepDetailsUnion] to add type safety for the return type of
+// [RunStepStepDetailsUnion.AsAny]
+type anyRunStepStepDetails interface {
+ implRunStepStepDetailsUnion()
+}
+
+func (MessageCreationStepDetails) implRunStepStepDetailsUnion() {}
+func (ToolCallsStepDetails) implRunStepStepDetailsUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := RunStepStepDetailsUnion.AsAny().(type) {
-// case MessageCreationStepDetails:
-// case ToolCallsStepDetails:
+// case openai.MessageCreationStepDetails:
+// case openai.ToolCallsStepDetails:
// default:
// fmt.Errorf("no variant present")
// }
-func (u RunStepStepDetailsUnion) AsAny() any {
+func (u RunStepStepDetailsUnion) AsAny() anyRunStepStepDetails {
switch u.Type {
case "message_creation":
return u.AsMessageCreation()
@@ -931,13 +948,12 @@ type RunStepUsage struct {
PromptTokens int64 `json:"prompt_tokens,required"`
// Total number of tokens used (prompt + completion).
TotalTokens int64 `json:"total_tokens,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- CompletionTokens resp.Field
- PromptTokens resp.Field
- TotalTokens resp.Field
- ExtraFields map[string]resp.Field
+ CompletionTokens respjson.Field
+ PromptTokens respjson.Field
+ TotalTokens respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -952,11 +968,10 @@ func (r *RunStepUsage) UnmarshalJSON(data []byte) error {
type RunStepDelta struct {
// The details of the run step.
StepDetails RunStepDeltaStepDetailsUnion `json:"step_details"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- StepDetails resp.Field
- ExtraFields map[string]resp.Field
+ StepDetails respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -981,22 +996,32 @@ type RunStepDeltaStepDetailsUnion struct {
// This field is from variant [ToolCallDeltaObject].
ToolCalls []ToolCallDeltaUnion `json:"tool_calls"`
JSON struct {
- Type resp.Field
- MessageCreation resp.Field
- ToolCalls resp.Field
+ Type respjson.Field
+ MessageCreation respjson.Field
+ ToolCalls respjson.Field
raw string
} `json:"-"`
}
+// anyRunStepDeltaStepDetails is implemented by each variant of
+// [RunStepDeltaStepDetailsUnion] to add type safety for the return type of
+// [RunStepDeltaStepDetailsUnion.AsAny]
+type anyRunStepDeltaStepDetails interface {
+ implRunStepDeltaStepDetailsUnion()
+}
+
+func (RunStepDeltaMessageDelta) implRunStepDeltaStepDetailsUnion() {}
+func (ToolCallDeltaObject) implRunStepDeltaStepDetailsUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := RunStepDeltaStepDetailsUnion.AsAny().(type) {
-// case RunStepDeltaMessageDelta:
-// case ToolCallDeltaObject:
+// case openai.RunStepDeltaMessageDelta:
+// case openai.ToolCallDeltaObject:
// default:
// fmt.Errorf("no variant present")
// }
-func (u RunStepDeltaStepDetailsUnion) AsAny() any {
+func (u RunStepDeltaStepDetailsUnion) AsAny() anyRunStepDeltaStepDetails {
switch u.Type {
case "message_creation":
return u.AsMessageCreation()
@@ -1032,13 +1057,12 @@ type RunStepDeltaEvent struct {
Delta RunStepDelta `json:"delta,required"`
// The object type, which is always `thread.run.step.delta`.
Object constant.ThreadRunStepDelta `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Delta resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Delta respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1054,12 +1078,11 @@ type RunStepDeltaMessageDelta struct {
// Always `message_creation`.
Type constant.MessageCreation `json:"type,required"`
MessageCreation RunStepDeltaMessageDeltaMessageCreation `json:"message_creation"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- MessageCreation resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ MessageCreation respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1073,11 +1096,10 @@ func (r *RunStepDeltaMessageDelta) UnmarshalJSON(data []byte) error {
type RunStepDeltaMessageDeltaMessageCreation struct {
// The ID of the message that was created by this run step.
MessageID string `json:"message_id"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- MessageID resp.Field
- ExtraFields map[string]resp.Field
+ MessageID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1111,25 +1133,35 @@ type ToolCallUnion struct {
// This field is from variant [FunctionToolCall].
Function FunctionToolCallFunction `json:"function"`
JSON struct {
- ID resp.Field
- CodeInterpreter resp.Field
- Type resp.Field
- FileSearch resp.Field
- Function resp.Field
+ ID respjson.Field
+ CodeInterpreter respjson.Field
+ Type respjson.Field
+ FileSearch respjson.Field
+ Function respjson.Field
raw string
} `json:"-"`
}
+// anyToolCall is implemented by each variant of [ToolCallUnion] to add type safety
+// for the return type of [ToolCallUnion.AsAny]
+type anyToolCall interface {
+ implToolCallUnion()
+}
+
+func (CodeInterpreterToolCall) implToolCallUnion() {}
+func (FileSearchToolCall) implToolCallUnion() {}
+func (FunctionToolCall) implToolCallUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := ToolCallUnion.AsAny().(type) {
-// case CodeInterpreterToolCall:
-// case FileSearchToolCall:
-// case FunctionToolCall:
+// case openai.CodeInterpreterToolCall:
+// case openai.FileSearchToolCall:
+// case openai.FunctionToolCall:
// default:
// fmt.Errorf("no variant present")
// }
-func (u ToolCallUnion) AsAny() any {
+func (u ToolCallUnion) AsAny() anyToolCall {
switch u.Type {
case "code_interpreter":
return u.AsCodeInterpreter()
@@ -1178,30 +1210,40 @@ type ToolCallDeltaUnion struct {
// This field is from variant [CodeInterpreterToolCallDelta].
CodeInterpreter CodeInterpreterToolCallDeltaCodeInterpreter `json:"code_interpreter"`
// This field is from variant [FileSearchToolCallDelta].
- FileSearch interface{} `json:"file_search"`
+ FileSearch any `json:"file_search"`
// This field is from variant [FunctionToolCallDelta].
Function FunctionToolCallDeltaFunction `json:"function"`
JSON struct {
- Index resp.Field
- Type resp.Field
- ID resp.Field
- CodeInterpreter resp.Field
- FileSearch resp.Field
- Function resp.Field
+ Index respjson.Field
+ Type respjson.Field
+ ID respjson.Field
+ CodeInterpreter respjson.Field
+ FileSearch respjson.Field
+ Function respjson.Field
raw string
} `json:"-"`
}
+// anyToolCallDelta is implemented by each variant of [ToolCallDeltaUnion] to add
+// type safety for the return type of [ToolCallDeltaUnion.AsAny]
+type anyToolCallDelta interface {
+ implToolCallDeltaUnion()
+}
+
+func (CodeInterpreterToolCallDelta) implToolCallDeltaUnion() {}
+func (FileSearchToolCallDelta) implToolCallDeltaUnion() {}
+func (FunctionToolCallDelta) implToolCallDeltaUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := ToolCallDeltaUnion.AsAny().(type) {
-// case CodeInterpreterToolCallDelta:
-// case FileSearchToolCallDelta:
-// case FunctionToolCallDelta:
+// case openai.CodeInterpreterToolCallDelta:
+// case openai.FileSearchToolCallDelta:
+// case openai.FunctionToolCallDelta:
// default:
// fmt.Errorf("no variant present")
// }
-func (u ToolCallDeltaUnion) AsAny() any {
+func (u ToolCallDeltaUnion) AsAny() anyToolCallDelta {
switch u.Type {
case "code_interpreter":
return u.AsCodeInterpreter()
@@ -1243,12 +1285,11 @@ type ToolCallDeltaObject struct {
// with one of three types of tools: `code_interpreter`, `file_search`, or
// `function`.
ToolCalls []ToolCallDeltaUnion `json:"tool_calls"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ToolCalls resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ ToolCalls respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1267,12 +1308,11 @@ type ToolCallsStepDetails struct {
ToolCalls []ToolCallUnion `json:"tool_calls,required"`
// Always `tool_calls`.
Type constant.ToolCalls `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ToolCalls resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ToolCalls respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1295,13 +1335,9 @@ type BetaThreadRunStepGetParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunStepGetParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [BetaThreadRunStepGetParams]'s query parameters as
// `url.Values`.
-func (r BetaThreadRunStepGetParams) URLQuery() (v url.Values) {
+func (r BetaThreadRunStepGetParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -1338,13 +1374,9 @@ type BetaThreadRunStepListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f BetaThreadRunStepListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [BetaThreadRunStepListParams]'s query parameters as
// `url.Values`.
-func (r BetaThreadRunStepListParams) URLQuery() (v url.Values) {
+func (r BetaThreadRunStepListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -4,11 +4,11 @@ package openai
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/apiquery"
@@ -16,11 +16,10 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/packages/ssestream"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// ChatCompletionService contains methods and other services that help with
@@ -176,9 +175,25 @@ type ChatCompletion struct {
Model string `json:"model,required"`
// The object type, which is always `chat.completion`.
Object constant.ChatCompletion `json:"object,required"`
- // The service tier used for processing the request.
+ // Specifies the processing type used for serving the request.
//
- // Any of "scale", "default".
+ // - If set to 'auto', then the request will be processed with the service tier
+ // configured in the Project settings. Unless otherwise configured, the Project
+ // will use 'default'.
+ // - If set to 'default', then the requset will be processed with the standard
+ // pricing and performance for the selected model.
+ // - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
+ // 'priority', then the request will be processed with the corresponding service
+ // tier. [Contact sales](https://openai.com/contact-sales) to learn more about
+ // Priority processing.
+ // - When not set, the default behavior is 'auto'.
+ //
+ // When the `service_tier` parameter is set, the response body will include the
+ // `service_tier` value based on the processing mode actually used to serve the
+ // request. This response value may be different from the value set in the
+ // parameter.
+ //
+ // Any of "auto", "default", "flex", "scale", "priority".
ServiceTier ChatCompletionServiceTier `json:"service_tier,nullable"`
// This fingerprint represents the backend configuration that the model runs with.
//
@@ -187,18 +202,17 @@ type ChatCompletion struct {
SystemFingerprint string `json:"system_fingerprint"`
// Usage statistics for the completion request.
Usage CompletionUsage `json:"usage"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Choices resp.Field
- Created resp.Field
- Model resp.Field
- Object resp.Field
- ServiceTier resp.Field
- SystemFingerprint resp.Field
- Usage resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Choices respjson.Field
+ Created respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ ServiceTier respjson.Field
+ SystemFingerprint respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -225,14 +239,13 @@ type ChatCompletionChoice struct {
Logprobs ChatCompletionChoiceLogprobs `json:"logprobs,required"`
// A chat completion message generated by the model.
Message ChatCompletionMessage `json:"message,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FinishReason resp.Field
- Index resp.Field
- Logprobs resp.Field
- Message resp.Field
- ExtraFields map[string]resp.Field
+ FinishReason respjson.Field
+ Index respjson.Field
+ Logprobs respjson.Field
+ Message respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -249,12 +262,11 @@ type ChatCompletionChoiceLogprobs struct {
Content []ChatCompletionTokenLogprob `json:"content,required"`
// A list of message refusal tokens with log probability information.
Refusal []ChatCompletionTokenLogprob `json:"refusal,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Content resp.Field
- Refusal resp.Field
- ExtraFields map[string]resp.Field
+ Content respjson.Field
+ Refusal respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -265,12 +277,31 @@ func (r *ChatCompletionChoiceLogprobs) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
-// The service tier used for processing the request.
+// Specifies the processing type used for serving the request.
+//
+// - If set to 'auto', then the request will be processed with the service tier
+// configured in the Project settings. Unless otherwise configured, the Project
+// will use 'default'.
+// - If set to 'default', then the requset will be processed with the standard
+// pricing and performance for the selected model.
+// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
+// 'priority', then the request will be processed with the corresponding service
+// tier. [Contact sales](https://openai.com/contact-sales) to learn more about
+// Priority processing.
+// - When not set, the default behavior is 'auto'.
+//
+// When the `service_tier` parameter is set, the response body will include the
+// `service_tier` value based on the processing mode actually used to serve the
+// request. This response value may be different from the value set in the
+// parameter.
type ChatCompletionServiceTier string
const (
- ChatCompletionServiceTierScale ChatCompletionServiceTier = "scale"
- ChatCompletionServiceTierDefault ChatCompletionServiceTier = "default"
+ ChatCompletionServiceTierAuto ChatCompletionServiceTier = "auto"
+ ChatCompletionServiceTierDefault ChatCompletionServiceTier = "default"
+ ChatCompletionServiceTierFlex ChatCompletionServiceTier = "flex"
+ ChatCompletionServiceTierScale ChatCompletionServiceTier = "scale"
+ ChatCompletionServiceTierPriority ChatCompletionServiceTier = "priority"
)
// Messages sent by the model in response to user messages.
@@ -302,15 +333,13 @@ type ChatCompletionAssistantMessageParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionAssistantMessageParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionAssistantMessageParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionAssistantMessageParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionAssistantMessageParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Data about a previous audio response from the model.
// [Learn more](https://platform.openai.com/docs/guides/audio).
@@ -322,15 +351,13 @@ type ChatCompletionAssistantMessageParamAudio struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionAssistantMessageParamAudio) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionAssistantMessageParamAudio) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionAssistantMessageParamAudio
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionAssistantMessageParamAudio) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -341,13 +368,11 @@ type ChatCompletionAssistantMessageParamContentUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ChatCompletionAssistantMessageParamContentUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u ChatCompletionAssistantMessageParamContentUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ChatCompletionAssistantMessageParamContentUnion](u.OfString, u.OfArrayOfContentParts)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
+}
+func (u *ChatCompletionAssistantMessageParamContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ChatCompletionAssistantMessageParamContentUnion) asAny() any {
@@ -368,13 +393,11 @@ type ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion](u.OfText, u.OfRefusal)
+ return param.MarshalUnion(u, u.OfText, u.OfRefusal)
+}
+func (u *ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) asAny() any {
@@ -415,16 +438,8 @@ func (u ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion) GetTy
func init() {
apijson.RegisterUnion[ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ChatCompletionContentPartTextParam{}),
- DiscriminatorValue: "text",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ChatCompletionContentPartRefusalParam{}),
- DiscriminatorValue: "refusal",
- },
+ apijson.Discriminator[ChatCompletionContentPartTextParam]("text"),
+ apijson.Discriminator[ChatCompletionContentPartRefusalParam]("refusal"),
)
}
@@ -445,15 +460,13 @@ type ChatCompletionAssistantMessageParamFunctionCall struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionAssistantMessageParamFunctionCall) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionAssistantMessageParamFunctionCall) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionAssistantMessageParamFunctionCall
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionAssistantMessageParamFunctionCall) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// If the audio output modality is requested, this object contains data about the
// audio response from the model.
@@ -469,14 +482,13 @@ type ChatCompletionAudio struct {
ExpiresAt int64 `json:"expires_at,required"`
// Transcript of the audio generated by the model.
Transcript string `json:"transcript,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Data resp.Field
- ExpiresAt resp.Field
- Transcript resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Data respjson.Field
+ ExpiresAt respjson.Field
+ Transcript respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -496,23 +508,21 @@ type ChatCompletionAudioParam struct {
// Specifies the output audio format. Must be one of `wav`, `mp3`, `flac`, `opus`,
// or `pcm16`.
//
- // Any of "wav", "mp3", "flac", "opus", "pcm16".
+ // Any of "wav", "aac", "mp3", "flac", "opus", "pcm16".
Format ChatCompletionAudioParamFormat `json:"format,omitzero,required"`
// The voice the model uses to respond. Supported voices are `alloy`, `ash`,
- // `ballad`, `coral`, `echo`, `sage`, and `shimmer`.
- //
- // Any of "alloy", "ash", "ballad", "coral", "echo", "sage", "shimmer", "verse".
+ // `ballad`, `coral`, `echo`, `fable`, `nova`, `onyx`, `sage`, and `shimmer`.
Voice ChatCompletionAudioParamVoice `json:"voice,omitzero,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionAudioParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ChatCompletionAudioParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionAudioParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionAudioParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Specifies the output audio format. Must be one of `wav`, `mp3`, `flac`, `opus`,
// or `pcm16`.
@@ -520,6 +530,7 @@ type ChatCompletionAudioParamFormat string
const (
ChatCompletionAudioParamFormatWAV ChatCompletionAudioParamFormat = "wav"
+ ChatCompletionAudioParamFormatAAC ChatCompletionAudioParamFormat = "aac"
ChatCompletionAudioParamFormatMP3 ChatCompletionAudioParamFormat = "mp3"
ChatCompletionAudioParamFormatFLAC ChatCompletionAudioParamFormat = "flac"
ChatCompletionAudioParamFormatOpus ChatCompletionAudioParamFormat = "opus"
@@ -527,7 +538,7 @@ const (
)
// The voice the model uses to respond. Supported voices are `alloy`, `ash`,
-// `ballad`, `coral`, `echo`, `sage`, and `shimmer`.
+// `ballad`, `coral`, `echo`, `fable`, `nova`, `onyx`, `sage`, and `shimmer`.
type ChatCompletionAudioParamVoice string
const (
@@ -536,6 +547,9 @@ const (
ChatCompletionAudioParamVoiceBallad ChatCompletionAudioParamVoice = "ballad"
ChatCompletionAudioParamVoiceCoral ChatCompletionAudioParamVoice = "coral"
ChatCompletionAudioParamVoiceEcho ChatCompletionAudioParamVoice = "echo"
+ ChatCompletionAudioParamVoiceFable ChatCompletionAudioParamVoice = "fable"
+ ChatCompletionAudioParamVoiceOnyx ChatCompletionAudioParamVoice = "onyx"
+ ChatCompletionAudioParamVoiceNova ChatCompletionAudioParamVoice = "nova"
ChatCompletionAudioParamVoiceSage ChatCompletionAudioParamVoice = "sage"
ChatCompletionAudioParamVoiceShimmer ChatCompletionAudioParamVoice = "shimmer"
ChatCompletionAudioParamVoiceVerse ChatCompletionAudioParamVoice = "verse"
@@ -558,9 +572,25 @@ type ChatCompletionChunk struct {
Model string `json:"model,required"`
// The object type, which is always `chat.completion.chunk`.
Object constant.ChatCompletionChunk `json:"object,required"`
- // The service tier used for processing the request.
+ // Specifies the processing type used for serving the request.
+ //
+ // - If set to 'auto', then the request will be processed with the service tier
+ // configured in the Project settings. Unless otherwise configured, the Project
+ // will use 'default'.
+ // - If set to 'default', then the requset will be processed with the standard
+ // pricing and performance for the selected model.
+ // - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
+ // 'priority', then the request will be processed with the corresponding service
+ // tier. [Contact sales](https://openai.com/contact-sales) to learn more about
+ // Priority processing.
+ // - When not set, the default behavior is 'auto'.
+ //
+ // When the `service_tier` parameter is set, the response body will include the
+ // `service_tier` value based on the processing mode actually used to serve the
+ // request. This response value may be different from the value set in the
+ // parameter.
//
- // Any of "scale", "default".
+ // Any of "auto", "default", "flex", "scale", "priority".
ServiceTier ChatCompletionChunkServiceTier `json:"service_tier,nullable"`
// This fingerprint represents the backend configuration that the model runs with.
// Can be used in conjunction with the `seed` request parameter to understand when
@@ -574,18 +604,17 @@ type ChatCompletionChunk struct {
// **NOTE:** If the stream is interrupted or cancelled, you may not receive the
// final usage chunk which contains the total token usage for the request.
Usage CompletionUsage `json:"usage,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Choices resp.Field
- Created resp.Field
- Model resp.Field
- Object resp.Field
- ServiceTier resp.Field
- SystemFingerprint resp.Field
- Usage resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Choices respjson.Field
+ Created respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ ServiceTier respjson.Field
+ SystemFingerprint respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -612,14 +641,13 @@ type ChatCompletionChunkChoice struct {
Index int64 `json:"index,required"`
// Log probability information for the choice.
Logprobs ChatCompletionChunkChoiceLogprobs `json:"logprobs,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Delta resp.Field
- FinishReason resp.Field
- Index resp.Field
- Logprobs resp.Field
- ExtraFields map[string]resp.Field
+ Delta respjson.Field
+ FinishReason respjson.Field
+ Index respjson.Field
+ Logprobs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -646,15 +674,14 @@ type ChatCompletionChunkChoiceDelta struct {
// Any of "developer", "system", "user", "assistant", "tool".
Role string `json:"role"`
ToolCalls []ChatCompletionChunkChoiceDeltaToolCall `json:"tool_calls"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Content resp.Field
- FunctionCall resp.Field
- Refusal resp.Field
- Role resp.Field
- ToolCalls resp.Field
- ExtraFields map[string]resp.Field
+ Content respjson.Field
+ FunctionCall respjson.Field
+ Refusal respjson.Field
+ Role respjson.Field
+ ToolCalls respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -677,12 +704,11 @@ type ChatCompletionChunkChoiceDeltaFunctionCall struct {
Arguments string `json:"arguments"`
// The name of the function to call.
Name string `json:"name"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Arguments resp.Field
- Name resp.Field
- ExtraFields map[string]resp.Field
+ Arguments respjson.Field
+ Name respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -702,14 +728,13 @@ type ChatCompletionChunkChoiceDeltaToolCall struct {
//
// Any of "function".
Type string `json:"type"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Index resp.Field
- ID resp.Field
- Function resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Index respjson.Field
+ ID respjson.Field
+ Function respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -728,12 +753,11 @@ type ChatCompletionChunkChoiceDeltaToolCallFunction struct {
Arguments string `json:"arguments"`
// The name of the function to call.
Name string `json:"name"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Arguments resp.Field
- Name resp.Field
- ExtraFields map[string]resp.Field
+ Arguments respjson.Field
+ Name respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -750,12 +774,11 @@ type ChatCompletionChunkChoiceLogprobs struct {
Content []ChatCompletionTokenLogprob `json:"content,required"`
// A list of message refusal tokens with log probability information.
Refusal []ChatCompletionTokenLogprob `json:"refusal,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Content resp.Field
- Refusal resp.Field
- ExtraFields map[string]resp.Field
+ Content respjson.Field
+ Refusal respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -766,12 +789,31 @@ func (r *ChatCompletionChunkChoiceLogprobs) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
-// The service tier used for processing the request.
+// Specifies the processing type used for serving the request.
+//
+// - If set to 'auto', then the request will be processed with the service tier
+// configured in the Project settings. Unless otherwise configured, the Project
+// will use 'default'.
+// - If set to 'default', then the requset will be processed with the standard
+// pricing and performance for the selected model.
+// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
+// 'priority', then the request will be processed with the corresponding service
+// tier. [Contact sales](https://openai.com/contact-sales) to learn more about
+// Priority processing.
+// - When not set, the default behavior is 'auto'.
+//
+// When the `service_tier` parameter is set, the response body will include the
+// `service_tier` value based on the processing mode actually used to serve the
+// request. This response value may be different from the value set in the
+// parameter.
type ChatCompletionChunkServiceTier string
const (
- ChatCompletionChunkServiceTierScale ChatCompletionChunkServiceTier = "scale"
- ChatCompletionChunkServiceTierDefault ChatCompletionChunkServiceTier = "default"
+ ChatCompletionChunkServiceTierAuto ChatCompletionChunkServiceTier = "auto"
+ ChatCompletionChunkServiceTierDefault ChatCompletionChunkServiceTier = "default"
+ ChatCompletionChunkServiceTierFlex ChatCompletionChunkServiceTier = "flex"
+ ChatCompletionChunkServiceTierScale ChatCompletionChunkServiceTier = "scale"
+ ChatCompletionChunkServiceTierPriority ChatCompletionChunkServiceTier = "priority"
)
func TextContentPart(text string) ChatCompletionContentPartUnionParam {
@@ -809,13 +851,11 @@ type ChatCompletionContentPartUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ChatCompletionContentPartUnionParam) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u ChatCompletionContentPartUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ChatCompletionContentPartUnionParam](u.OfText, u.OfImageURL, u.OfInputAudio, u.OfFile)
+ return param.MarshalUnion(u, u.OfText, u.OfImageURL, u.OfInputAudio, u.OfFile)
+}
+func (u *ChatCompletionContentPartUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ChatCompletionContentPartUnionParam) asAny() any {
@@ -880,26 +920,10 @@ func (u ChatCompletionContentPartUnionParam) GetType() *string {
func init() {
apijson.RegisterUnion[ChatCompletionContentPartUnionParam](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ChatCompletionContentPartTextParam{}),
- DiscriminatorValue: "text",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ChatCompletionContentPartImageParam{}),
- DiscriminatorValue: "image_url",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ChatCompletionContentPartInputAudioParam{}),
- DiscriminatorValue: "input_audio",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ChatCompletionContentPartFileParam{}),
- DiscriminatorValue: "file",
- },
+ apijson.Discriminator[ChatCompletionContentPartTextParam]("text"),
+ apijson.Discriminator[ChatCompletionContentPartImageParam]("image_url"),
+ apijson.Discriminator[ChatCompletionContentPartInputAudioParam]("input_audio"),
+ apijson.Discriminator[ChatCompletionContentPartFileParam]("file"),
)
}
@@ -916,15 +940,13 @@ type ChatCompletionContentPartFileParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartFileParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartFileParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartFileParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartFileParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type ChatCompletionContentPartFileFileParam struct {
// The base64 encoded file data, used when passing the file to the model as a
@@ -937,15 +959,13 @@ type ChatCompletionContentPartFileFileParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartFileFileParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartFileFileParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartFileFileParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartFileFileParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Learn about [image inputs](https://platform.openai.com/docs/guides/vision).
//
@@ -959,15 +979,13 @@ type ChatCompletionContentPartImageParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartImageParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartImageParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartImageParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartImageParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The property URL is required.
type ChatCompletionContentPartImageImageURLParam struct {
@@ -981,19 +999,17 @@ type ChatCompletionContentPartImageImageURLParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartImageImageURLParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartImageImageURLParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartImageImageURLParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartImageImageURLParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[ChatCompletionContentPartImageImageURLParam](
- "Detail", false, "auto", "low", "high",
+ "detail", "auto", "low", "high",
)
}
@@ -1009,15 +1025,13 @@ type ChatCompletionContentPartInputAudioParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartInputAudioParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartInputAudioParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartInputAudioParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartInputAudioParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The properties Data, Format are required.
type ChatCompletionContentPartInputAudioInputAudioParam struct {
@@ -1030,19 +1044,17 @@ type ChatCompletionContentPartInputAudioInputAudioParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartInputAudioInputAudioParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartInputAudioInputAudioParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartInputAudioInputAudioParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartInputAudioInputAudioParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[ChatCompletionContentPartInputAudioInputAudioParam](
- "Format", false, "wav", "mp3",
+ "format", "wav", "mp3",
)
}
@@ -1057,15 +1069,13 @@ type ChatCompletionContentPartRefusalParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartRefusalParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartRefusalParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartRefusalParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartRefusalParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Learn about
// [text inputs](https://platform.openai.com/docs/guides/text-generation).
@@ -1081,15 +1091,13 @@ type ChatCompletionContentPartTextParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionContentPartTextParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionContentPartTextParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionContentPartTextParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionContentPartTextParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type ChatCompletionDeleted struct {
// The ID of the chat completion that was deleted.
@@ -1098,13 +1106,12 @@ type ChatCompletionDeleted struct {
Deleted bool `json:"deleted,required"`
// The type of object being deleted.
Object constant.ChatCompletionDeleted `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1133,15 +1140,13 @@ type ChatCompletionDeveloperMessageParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionDeveloperMessageParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionDeveloperMessageParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionDeveloperMessageParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionDeveloperMessageParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -1152,13 +1157,11 @@ type ChatCompletionDeveloperMessageParamContentUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ChatCompletionDeveloperMessageParamContentUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u ChatCompletionDeveloperMessageParamContentUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ChatCompletionDeveloperMessageParamContentUnion](u.OfString, u.OfArrayOfContentParts)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfContentParts)
+}
+func (u *ChatCompletionDeveloperMessageParamContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ChatCompletionDeveloperMessageParamContentUnion) asAny() any {
@@ -1180,15 +1183,13 @@ type ChatCompletionFunctionCallOptionParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionFunctionCallOptionParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionFunctionCallOptionParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionFunctionCallOptionParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionFunctionCallOptionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Deprecated: deprecated
//
@@ -1205,15 +1206,13 @@ type ChatCompletionFunctionMessageParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionFunctionMessageParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ChatCompletionFunctionMessageParam) MarshalJSON() (data []byte, err error) {
type shadow ChatCompletionFunctionMessageParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ChatCompletionFunctionMessageParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// A chat completion message generated by the model.
type ChatCompletionMessage struct {
@@ -1237,17 +1236,16 @@ type ChatCompletionMessage struct {
FunctionCall ChatCompletionMessageFunctionCall `json:"function_call"`
// The tool calls generated by the model, such as function calls.
ToolCalls []ChatCompletionMessageToolCall `json:"tool_calls"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Content resp.Field
- Refusal resp.Field
- Role resp.Field
- Annotations resp.Field
- Audio resp.Field
- FunctionCall resp.Field
- ToolCalls resp.Field
- ExtraFields map[string]resp.Field
+ Content respjson.Field
+ Refusal respjson.Field
+ Role respjson.Field
+ Annotations respjson.Field
+ Audio respjson.Field
+ FunctionCall respjson.Field
+ ToolCalls respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1300,12 +1298,11 @@ type ChatCompletionMessageAnnotation struct {
Type constant.URLCitation `json:"type,required"`
// A URL citation when using web search.
URLCitation ChatCompletionMessageAnnotationURLCitation `json:"url_citation,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- URLCitation resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ URLCitation respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1326,14 +1323,13 @@ type ChatCompletionMessageAnnotationURLCitation struct {
Title string `json:"title,required"`
// The URL of the web resource.
URL string `json:"url,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- EndIndex resp.Field
- StartIndex resp.Field
- Title resp.Field
- URL resp.Field
- ExtraFields map[string]resp.Field
+ EndIndex respjson.Field
+ StartIndex respjson.Field
+ Title respjson.Field
+ URL respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1356,12 +1352,11 @@ type ChatCompletionMessageFunctionCall struct {
Arguments string `json:"arguments,required"`
// The name of the function to call.
Name string `json:"name,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Arguments resp.Field
- Name resp.Field
- ExtraFields map[string]resp.Field
+ Arguments respjson.Field
+ Name respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -1448,17 +1443,17 @@ type ChatCompletionMessageParamUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ChatCompletionMessageParamUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u ChatCompletionMessageParamUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ChatCompletionMessageParamUnion](u.OfDeveloper,
+ return param.MarshalUnion(u, u.OfDeveloper,
u.OfSystem,
u.OfUser,
u.OfAssistant,
u.OfTool,
u.OfFunction)
}
+func (u *ChatCompletionMessageParamUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
func (u *ChatCompletionMessageParamUnion) asAny() any {
if !param.IsOmitted(u.OfDeveloper) {
@@ -1495,7 +1490,7 @@ func (u ChatCompletionMessageParamUnion) GetFunctionCall() *ChatCompletionAssist
// Returns a pointer to the underlying variant's property, if present.
func (u ChatCompletionMessageParamUnion) GetRefusal() *string {
- if vt := u.OfAssistant; vt != nil && vt.Refusal.IsPresent() {
+ if vt := u.OfAssistant; vt != nil && vt.Refusal.Valid() {
return &vt.Refusal.Value
}
return nil
@@ -1537,13 +1532,13 @@ func (u ChatCompletionMessageParamUnion) GetRole() *string {
// Returns a pointer to the underlying variant's property, if present.
func (u ChatCompletionMessageParamUnion) GetName() *string {
- if vt := u.OfDeveloper; vt != nil && vt.Name.IsPresent() {
+ if vt := u.OfDeveloper; vt != nil && vt.Name.Valid() {
return &vt.Name.Value
- } else if vt := u.OfSystem; vt != nil && vt.Name.IsPresent() {
+ } else if vt := u.OfSystem; vt != nil && vt.Name.Valid() {
return &vt.Name.Value
- } else if vt := u.OfUser; vt != nil && vt.Name.IsPresent() {
+ } else if vt := u.OfUser; vt != nil && vt.Name.Valid() {
return &vt.Name.Value
- } else if vt := u.OfAssistant; vt != nil && vt.Name.IsPresent() {
+ } else if vt := u.OfAssistant; vt != nil && vt.Name.Valid() {
return &vt.Name.Value
} else if vt := u.OfFunction; vt != nil {
return (*string)(&vt.Name)
@@ -1556,32 +1551,25 @@ func (u ChatCompletionMessageParamUnion) GetName() *string {
// Or use AsAny() to get the underlying value
func (u ChatCompletionMessageParamUnion) GetContent() (res chatCompletionMessageParamUnionContent) {
if vt := u.OfDeveloper; vt != nil {
- res.ofChatCompletionDeveloperMessageContent = &vt.Content
+ res.any = vt.Content.asAny()
} else if vt := u.OfSystem; vt != nil {
- res.ofChatCompletionSystemMessageContent = &vt.Content
+ res.any = vt.Content.asAny()
} else if vt := u.OfUser; vt != nil {
- res.ofChatCompletionUserMessageContent = &vt.Content
+ res.any = vt.Content.asAny()
} else if vt := u.OfAssistant; vt != nil {
- res.ofChatCompletionAssistantMessageContent = &vt.Content
+ res.any = vt.Content.asAny()
} else if vt := u.OfTool; vt != nil {
- res.ofChatCompletionToolMessageContent = &vt.Content
- } else if vt := u.OfFunction; vt != nil && vt.Content.IsPresent() {
- res.ofString = &vt.Content.Value
+ res.any = vt.Content.asAny()
+ } else if vt := u.OfFunction; vt != nil && vt.Content.Valid() {
+ res.any = &vt.Content.Value
}
return
}
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type chatCompletionMessageParamUnionContent struct {
- ofChatCompletionDeveloperMessageContent *ChatCompletionDeveloperMessageParamContentUnion
- ofChatCompletionSystemMessageContent *ChatCompletionSystemMessageParamContentUnion
- ofChatCompletionUserMessageContent *ChatCompletionUserMessageParamContentUnion
- ofChatCompletionAssistantMessageContent *ChatCompletionAssistantMessageParamContentUnion
- ofChatCompletionToolMessageContent *ChatCompletionToolMessageParamContentUnion
- ofString *string
-}
+// Can have the runtime types [*string], [_[]ChatCompletionContentPartTextParam],
+// [_[]ChatCompletionContentPartUnionParam],
+// [\*[]ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion]
+type chatCompletionMessageParamUnionContent struct{ any }
// Use the following switch statement to get the type of the union:
//
@@ -77,13 +77,9 @@ type ChatCompletionMessageListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ChatCompletionMessageListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [ChatCompletionMessageListParams]'s query parameters as
// `url.Values`.
-func (r ChatCompletionMessageListParams) URLQuery() (v url.Values) {
+func (r ChatCompletionMessageListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -10,6 +10,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/responses"
+ "github.com/openai/openai-go/webhooks"
)
// Client creates a struct with services and top level methods that help with
@@ -26,17 +27,24 @@ type Client struct {
Moderations ModerationService
Models ModelService
FineTuning FineTuningService
+ Graders GraderService
VectorStores VectorStoreService
+ Webhooks webhooks.WebhookService
Beta BetaService
Batches BatchService
Uploads UploadService
Responses responses.ResponseService
+ Containers ContainerService
}
// DefaultClientOptions read from the environment (OPENAI_API_KEY, OPENAI_ORG_ID,
-// OPENAI_PROJECT_ID). This should be used to initialize new clients.
+// OPENAI_PROJECT_ID, OPENAI_WEBHOOK_SECRET, OPENAI_BASE_URL). This should be used
+// to initialize new clients.
func DefaultClientOptions() []option.RequestOption {
defaults := []option.RequestOption{option.WithEnvironmentProduction()}
+ if o, ok := os.LookupEnv("OPENAI_BASE_URL"); ok {
+ defaults = append(defaults, option.WithBaseURL(o))
+ }
if o, ok := os.LookupEnv("OPENAI_API_KEY"); ok {
defaults = append(defaults, option.WithAPIKey(o))
}
@@ -46,13 +54,17 @@ func DefaultClientOptions() []option.RequestOption {
if o, ok := os.LookupEnv("OPENAI_PROJECT_ID"); ok {
defaults = append(defaults, option.WithProject(o))
}
+ if o, ok := os.LookupEnv("OPENAI_WEBHOOK_SECRET"); ok {
+ defaults = append(defaults, option.WithWebhookSecret(o))
+ }
return defaults
}
// NewClient generates a new client with the default option read from the
-// environment (OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID). The option
-// passed in as arguments are applied after these default arguments, and all option
-// will be passed down to the services and requests that this client makes.
+// environment (OPENAI_API_KEY, OPENAI_ORG_ID, OPENAI_PROJECT_ID,
+// OPENAI_WEBHOOK_SECRET, OPENAI_BASE_URL). The option passed in as arguments are
+// applied after these default arguments, and all option will be passed down to the
+// services and requests that this client makes.
func NewClient(opts ...option.RequestOption) (r Client) {
opts = append(DefaultClientOptions(), opts...)
@@ -67,11 +79,14 @@ func NewClient(opts ...option.RequestOption) (r Client) {
r.Moderations = NewModerationService(opts...)
r.Models = NewModelService(opts...)
r.FineTuning = NewFineTuningService(opts...)
+ r.Graders = NewGraderService(opts...)
r.VectorStores = NewVectorStoreService(opts...)
+ r.Webhooks = webhooks.NewWebhookService(opts...)
r.Beta = NewBetaService(opts...)
r.Batches = NewBatchService(opts...)
r.Uploads = NewUploadService(opts...)
r.Responses = responses.NewResponseService(opts...)
+ r.Containers = NewContainerService(opts...)
return
}
@@ -107,40 +122,40 @@ func NewClient(opts ...option.RequestOption) (r Client) {
//
// For even greater flexibility, see [option.WithResponseInto] and
// [option.WithResponseBodyInto].
-func (r *Client) Execute(ctx context.Context, method string, path string, params interface{}, res interface{}, opts ...option.RequestOption) error {
+func (r *Client) Execute(ctx context.Context, method string, path string, params any, res any, opts ...option.RequestOption) error {
opts = append(r.Options, opts...)
return requestconfig.ExecuteNewRequest(ctx, method, path, params, res, opts...)
}
// Get makes a GET request with the given URL, params, and optionally deserializes
// to a response. See [Execute] documentation on the params and response.
-func (r *Client) Get(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error {
+func (r *Client) Get(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error {
return r.Execute(ctx, http.MethodGet, path, params, res, opts...)
}
// Post makes a POST request with the given URL, params, and optionally
// deserializes to a response. See [Execute] documentation on the params and
// response.
-func (r *Client) Post(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error {
+func (r *Client) Post(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error {
return r.Execute(ctx, http.MethodPost, path, params, res, opts...)
}
// Put makes a PUT request with the given URL, params, and optionally deserializes
// to a response. See [Execute] documentation on the params and response.
-func (r *Client) Put(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error {
+func (r *Client) Put(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error {
return r.Execute(ctx, http.MethodPut, path, params, res, opts...)
}
// Patch makes a PATCH request with the given URL, params, and optionally
// deserializes to a response. See [Execute] documentation on the params and
// response.
-func (r *Client) Patch(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error {
+func (r *Client) Patch(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error {
return r.Execute(ctx, http.MethodPatch, path, params, res, opts...)
}
// Delete makes a DELETE request with the given URL, params, and optionally
// deserializes to a response. See [Execute] documentation on the params and
// response.
-func (r *Client) Delete(ctx context.Context, path string, params interface{}, res interface{}, opts ...option.RequestOption) error {
+func (r *Client) Delete(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error {
return r.Execute(ctx, http.MethodDelete, path, params, res, opts...)
}
@@ -10,7 +10,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/packages/ssestream"
"github.com/openai/openai-go/shared/constant"
)
@@ -75,17 +75,16 @@ type Completion struct {
SystemFingerprint string `json:"system_fingerprint"`
// Usage statistics for the completion request.
Usage CompletionUsage `json:"usage"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Choices resp.Field
- Created resp.Field
- Model resp.Field
- Object resp.Field
- SystemFingerprint resp.Field
- Usage resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Choices respjson.Field
+ Created respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ SystemFingerprint respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -107,14 +106,13 @@ type CompletionChoice struct {
Index int64 `json:"index,required"`
Logprobs CompletionChoiceLogprobs `json:"logprobs,required"`
Text string `json:"text,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FinishReason resp.Field
- Index resp.Field
- Logprobs resp.Field
- Text resp.Field
- ExtraFields map[string]resp.Field
+ FinishReason respjson.Field
+ Index respjson.Field
+ Logprobs respjson.Field
+ Text respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -142,14 +140,13 @@ type CompletionChoiceLogprobs struct {
TokenLogprobs []float64 `json:"token_logprobs"`
Tokens []string `json:"tokens"`
TopLogprobs []map[string]float64 `json:"top_logprobs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- TextOffset resp.Field
- TokenLogprobs resp.Field
- Tokens resp.Field
- TopLogprobs resp.Field
- ExtraFields map[string]resp.Field
+ TextOffset respjson.Field
+ TokenLogprobs respjson.Field
+ Tokens respjson.Field
+ TopLogprobs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -172,15 +169,14 @@ type CompletionUsage struct {
CompletionTokensDetails CompletionUsageCompletionTokensDetails `json:"completion_tokens_details"`
// Breakdown of tokens used in the prompt.
PromptTokensDetails CompletionUsagePromptTokensDetails `json:"prompt_tokens_details"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- CompletionTokens resp.Field
- PromptTokens resp.Field
- TotalTokens resp.Field
- CompletionTokensDetails resp.Field
- PromptTokensDetails resp.Field
- ExtraFields map[string]resp.Field
+ CompletionTokens respjson.Field
+ PromptTokens respjson.Field
+ TotalTokens respjson.Field
+ CompletionTokensDetails respjson.Field
+ PromptTokensDetails respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -205,14 +201,13 @@ type CompletionUsageCompletionTokensDetails struct {
// still counted in the total completion tokens for purposes of billing, output,
// and context window limits.
RejectedPredictionTokens int64 `json:"rejected_prediction_tokens"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- AcceptedPredictionTokens resp.Field
- AudioTokens resp.Field
- ReasoningTokens resp.Field
- RejectedPredictionTokens resp.Field
- ExtraFields map[string]resp.Field
+ AcceptedPredictionTokens respjson.Field
+ AudioTokens respjson.Field
+ ReasoningTokens respjson.Field
+ RejectedPredictionTokens respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -229,12 +224,11 @@ type CompletionUsagePromptTokensDetails struct {
AudioTokens int64 `json:"audio_tokens"`
// Cached tokens present in the prompt.
CachedTokens int64 `json:"cached_tokens"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- AudioTokens resp.Field
- CachedTokens resp.Field
- ExtraFields map[string]resp.Field
+ AudioTokens respjson.Field
+ CachedTokens respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -258,7 +252,7 @@ type CompletionNewParams struct {
// see all of your available models, or see our
// [Model overview](https://platform.openai.com/docs/models) for descriptions of
// them.
- Model string `json:"model,omitzero,required"`
+ Model CompletionNewParamsModel `json:"model,omitzero,required"`
// Generates `best_of` completions server-side and returns the "best" (the one with
// the highest log probability per token). Results cannot be streamed.
//
@@ -344,6 +338,8 @@ type CompletionNewParams struct {
// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token
// from being generated.
LogitBias map[string]int64 `json:"logit_bias,omitzero"`
+ // Not supported with latest reasoning models `o3` and `o4-mini`.
+ //
// Up to 4 sequences where the API will stop generating further tokens. The
// returned text will not contain the stop sequence.
Stop CompletionNewParamsStopUnion `json:"stop,omitzero"`
@@ -352,14 +348,26 @@ type CompletionNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f CompletionNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r CompletionNewParams) MarshalJSON() (data []byte, err error) {
type shadow CompletionNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *CompletionNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ID of the model to use. You can use the
+// [List models](https://platform.openai.com/docs/api-reference/models/list) API to
+// see all of your available models, or see our
+// [Model overview](https://platform.openai.com/docs/models) for descriptions of
+// them.
+type CompletionNewParamsModel string
+
+const (
+ CompletionNewParamsModelGPT3_5TurboInstruct CompletionNewParamsModel = "gpt-3.5-turbo-instruct"
+ CompletionNewParamsModelDavinci002 CompletionNewParamsModel = "davinci-002"
+ CompletionNewParamsModelBabbage002 CompletionNewParamsModel = "babbage-002"
+)
// Only one field can be non-zero.
//
@@ -372,11 +380,11 @@ type CompletionNewParamsPromptUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u CompletionNewParamsPromptUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u CompletionNewParamsPromptUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[CompletionNewParamsPromptUnion](u.OfString, u.OfArrayOfStrings, u.OfArrayOfTokens, u.OfArrayOfTokenArrays)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfStrings, u.OfArrayOfTokens, u.OfArrayOfTokenArrays)
+}
+func (u *CompletionNewParamsPromptUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *CompletionNewParamsPromptUnion) asAny() any {
@@ -396,23 +404,23 @@ func (u *CompletionNewParamsPromptUnion) asAny() any {
//
// Use [param.IsOmitted] to confirm if a field is set.
type CompletionNewParamsStopUnion struct {
- OfString param.Opt[string] `json:",omitzero,inline"`
- OfCompletionNewsStopArray []string `json:",omitzero,inline"`
+ OfString param.Opt[string] `json:",omitzero,inline"`
+ OfStringArray []string `json:",omitzero,inline"`
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u CompletionNewParamsStopUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u CompletionNewParamsStopUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[CompletionNewParamsStopUnion](u.OfString, u.OfCompletionNewsStopArray)
+ return param.MarshalUnion(u, u.OfString, u.OfStringArray)
+}
+func (u *CompletionNewParamsStopUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *CompletionNewParamsStopUnion) asAny() any {
if !param.IsOmitted(u.OfString) {
return &u.OfString.Value
- } else if !param.IsOmitted(u.OfCompletionNewsStopArray) {
- return &u.OfCompletionNewsStopArray
+ } else if !param.IsOmitted(u.OfStringArray) {
+ return &u.OfStringArray
}
return nil
}
@@ -0,0 +1,352 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/internal/apiquery"
+ "github.com/openai/openai-go/internal/requestconfig"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/pagination"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+)
+
+// ContainerService contains methods and other services that help with interacting
+// with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewContainerService] method instead.
+type ContainerService struct {
+ Options []option.RequestOption
+ Files ContainerFileService
+}
+
+// NewContainerService generates a new service that applies the given options to
+// each request. These options are applied after the parent client's options (if
+// there is one), and before any request-specific options.
+func NewContainerService(opts ...option.RequestOption) (r ContainerService) {
+ r = ContainerService{}
+ r.Options = opts
+ r.Files = NewContainerFileService(opts...)
+ return
+}
+
+// Create Container
+func (r *ContainerService) New(ctx context.Context, body ContainerNewParams, opts ...option.RequestOption) (res *ContainerNewResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ path := "containers"
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...)
+ return
+}
+
+// Retrieve Container
+func (r *ContainerService) Get(ctx context.Context, containerID string, opts ...option.RequestOption) (res *ContainerGetResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s", containerID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...)
+ return
+}
+
+// List Containers
+func (r *ContainerService) List(ctx context.Context, query ContainerListParams, opts ...option.RequestOption) (res *pagination.CursorPage[ContainerListResponse], err error) {
+ var raw *http.Response
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...)
+ path := "containers"
+ cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...)
+ if err != nil {
+ return nil, err
+ }
+ err = cfg.Execute()
+ if err != nil {
+ return nil, err
+ }
+ res.SetPageConfig(cfg, raw)
+ return res, nil
+}
+
+// List Containers
+func (r *ContainerService) ListAutoPaging(ctx context.Context, query ContainerListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[ContainerListResponse] {
+ return pagination.NewCursorPageAutoPager(r.List(ctx, query, opts...))
+}
+
+// Delete Container
+func (r *ContainerService) Delete(ctx context.Context, containerID string, opts ...option.RequestOption) (err error) {
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s", containerID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, nil, opts...)
+ return
+}
+
+type ContainerNewResponse struct {
+ // Unique identifier for the container.
+ ID string `json:"id,required"`
+ // Unix timestamp (in seconds) when the container was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // Name of the container.
+ Name string `json:"name,required"`
+ // The type of this object.
+ Object string `json:"object,required"`
+ // Status of the container (e.g., active, deleted).
+ Status string `json:"status,required"`
+ // The container will expire after this time period. The anchor is the reference
+ // point for the expiration. The minutes is the number of minutes after the anchor
+ // before the container expires.
+ ExpiresAfter ContainerNewResponseExpiresAfter `json:"expires_after"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Name respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ ExpiresAfter respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerNewResponse) RawJSON() string { return r.JSON.raw }
+func (r *ContainerNewResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The container will expire after this time period. The anchor is the reference
+// point for the expiration. The minutes is the number of minutes after the anchor
+// before the container expires.
+type ContainerNewResponseExpiresAfter struct {
+ // The reference point for the expiration.
+ //
+ // Any of "last_active_at".
+ Anchor string `json:"anchor"`
+ // The number of minutes after the anchor before the container expires.
+ Minutes int64 `json:"minutes"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Anchor respjson.Field
+ Minutes respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerNewResponseExpiresAfter) RawJSON() string { return r.JSON.raw }
+func (r *ContainerNewResponseExpiresAfter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type ContainerGetResponse struct {
+ // Unique identifier for the container.
+ ID string `json:"id,required"`
+ // Unix timestamp (in seconds) when the container was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // Name of the container.
+ Name string `json:"name,required"`
+ // The type of this object.
+ Object string `json:"object,required"`
+ // Status of the container (e.g., active, deleted).
+ Status string `json:"status,required"`
+ // The container will expire after this time period. The anchor is the reference
+ // point for the expiration. The minutes is the number of minutes after the anchor
+ // before the container expires.
+ ExpiresAfter ContainerGetResponseExpiresAfter `json:"expires_after"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Name respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ ExpiresAfter respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerGetResponse) RawJSON() string { return r.JSON.raw }
+func (r *ContainerGetResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The container will expire after this time period. The anchor is the reference
+// point for the expiration. The minutes is the number of minutes after the anchor
+// before the container expires.
+type ContainerGetResponseExpiresAfter struct {
+ // The reference point for the expiration.
+ //
+ // Any of "last_active_at".
+ Anchor string `json:"anchor"`
+ // The number of minutes after the anchor before the container expires.
+ Minutes int64 `json:"minutes"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Anchor respjson.Field
+ Minutes respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerGetResponseExpiresAfter) RawJSON() string { return r.JSON.raw }
+func (r *ContainerGetResponseExpiresAfter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type ContainerListResponse struct {
+ // Unique identifier for the container.
+ ID string `json:"id,required"`
+ // Unix timestamp (in seconds) when the container was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // Name of the container.
+ Name string `json:"name,required"`
+ // The type of this object.
+ Object string `json:"object,required"`
+ // Status of the container (e.g., active, deleted).
+ Status string `json:"status,required"`
+ // The container will expire after this time period. The anchor is the reference
+ // point for the expiration. The minutes is the number of minutes after the anchor
+ // before the container expires.
+ ExpiresAfter ContainerListResponseExpiresAfter `json:"expires_after"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Name respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ ExpiresAfter respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerListResponse) RawJSON() string { return r.JSON.raw }
+func (r *ContainerListResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The container will expire after this time period. The anchor is the reference
+// point for the expiration. The minutes is the number of minutes after the anchor
+// before the container expires.
+type ContainerListResponseExpiresAfter struct {
+ // The reference point for the expiration.
+ //
+ // Any of "last_active_at".
+ Anchor string `json:"anchor"`
+ // The number of minutes after the anchor before the container expires.
+ Minutes int64 `json:"minutes"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Anchor respjson.Field
+ Minutes respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerListResponseExpiresAfter) RawJSON() string { return r.JSON.raw }
+func (r *ContainerListResponseExpiresAfter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type ContainerNewParams struct {
+ // Name of the container to create.
+ Name string `json:"name,required"`
+ // Container expiration time in seconds relative to the 'anchor' time.
+ ExpiresAfter ContainerNewParamsExpiresAfter `json:"expires_after,omitzero"`
+ // IDs of files to copy to the container.
+ FileIDs []string `json:"file_ids,omitzero"`
+ paramObj
+}
+
+func (r ContainerNewParams) MarshalJSON() (data []byte, err error) {
+ type shadow ContainerNewParams
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ContainerNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Container expiration time in seconds relative to the 'anchor' time.
+//
+// The properties Anchor, Minutes are required.
+type ContainerNewParamsExpiresAfter struct {
+ // Time anchor for the expiration time. Currently only 'last_active_at' is
+ // supported.
+ //
+ // Any of "last_active_at".
+ Anchor string `json:"anchor,omitzero,required"`
+ Minutes int64 `json:"minutes,required"`
+ paramObj
+}
+
+func (r ContainerNewParamsExpiresAfter) MarshalJSON() (data []byte, err error) {
+ type shadow ContainerNewParamsExpiresAfter
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ContainerNewParamsExpiresAfter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+func init() {
+ apijson.RegisterFieldValidator[ContainerNewParamsExpiresAfter](
+ "anchor", "last_active_at",
+ )
+}
+
+type ContainerListParams struct {
+ // A cursor for use in pagination. `after` is an object ID that defines your place
+ // in the list. For instance, if you make a list request and receive 100 objects,
+ // ending with obj_foo, your subsequent call can include after=obj_foo in order to
+ // fetch the next page of the list.
+ After param.Opt[string] `query:"after,omitzero" json:"-"`
+ // A limit on the number of objects to be returned. Limit can range between 1 and
+ // 100, and the default is 20.
+ Limit param.Opt[int64] `query:"limit,omitzero" json:"-"`
+ // Sort order by the `created_at` timestamp of the objects. `asc` for ascending
+ // order and `desc` for descending order.
+ //
+ // Any of "asc", "desc".
+ Order ContainerListParamsOrder `query:"order,omitzero" json:"-"`
+ paramObj
+}
+
+// URLQuery serializes [ContainerListParams]'s query parameters as `url.Values`.
+func (r ContainerListParams) URLQuery() (v url.Values, err error) {
+ return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
+ ArrayFormat: apiquery.ArrayQueryFormatBrackets,
+ NestedFormat: apiquery.NestedQueryFormatBrackets,
+ })
+}
+
+// Sort order by the `created_at` timestamp of the objects. `asc` for ascending
+// order and `desc` for descending order.
+type ContainerListParamsOrder string
+
+const (
+ ContainerListParamsOrderAsc ContainerListParamsOrder = "asc"
+ ContainerListParamsOrderDesc ContainerListParamsOrder = "desc"
+)
@@ -0,0 +1,286 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "mime/multipart"
+ "net/http"
+ "net/url"
+
+ "github.com/openai/openai-go/internal/apiform"
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/internal/apiquery"
+ "github.com/openai/openai-go/internal/requestconfig"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/pagination"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+ "github.com/openai/openai-go/shared/constant"
+)
+
+// ContainerFileService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewContainerFileService] method instead.
+type ContainerFileService struct {
+ Options []option.RequestOption
+ Content ContainerFileContentService
+}
+
+// NewContainerFileService generates a new service that applies the given options
+// to each request. These options are applied after the parent client's options (if
+// there is one), and before any request-specific options.
+func NewContainerFileService(opts ...option.RequestOption) (r ContainerFileService) {
+ r = ContainerFileService{}
+ r.Options = opts
+ r.Content = NewContainerFileContentService(opts...)
+ return
+}
+
+// Create a Container File
+//
+// You can send either a multipart/form-data request with the raw file content, or
+// a JSON request with a file ID.
+func (r *ContainerFileService) New(ctx context.Context, containerID string, body ContainerFileNewParams, opts ...option.RequestOption) (res *ContainerFileNewResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s/files", containerID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...)
+ return
+}
+
+// Retrieve Container File
+func (r *ContainerFileService) Get(ctx context.Context, containerID string, fileID string, opts ...option.RequestOption) (res *ContainerFileGetResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ if fileID == "" {
+ err = errors.New("missing required file_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s/files/%s", containerID, fileID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...)
+ return
+}
+
+// List Container files
+func (r *ContainerFileService) List(ctx context.Context, containerID string, query ContainerFileListParams, opts ...option.RequestOption) (res *pagination.CursorPage[ContainerFileListResponse], err error) {
+ var raw *http.Response
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s/files", containerID)
+ cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodGet, path, query, &res, opts...)
+ if err != nil {
+ return nil, err
+ }
+ err = cfg.Execute()
+ if err != nil {
+ return nil, err
+ }
+ res.SetPageConfig(cfg, raw)
+ return res, nil
+}
+
+// List Container files
+func (r *ContainerFileService) ListAutoPaging(ctx context.Context, containerID string, query ContainerFileListParams, opts ...option.RequestOption) *pagination.CursorPageAutoPager[ContainerFileListResponse] {
+ return pagination.NewCursorPageAutoPager(r.List(ctx, containerID, query, opts...))
+}
+
+// Delete Container File
+func (r *ContainerFileService) Delete(ctx context.Context, containerID string, fileID string, opts ...option.RequestOption) (err error) {
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithHeader("Accept", "")}, opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ if fileID == "" {
+ err = errors.New("missing required file_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s/files/%s", containerID, fileID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, nil, opts...)
+ return
+}
+
+type ContainerFileNewResponse struct {
+ // Unique identifier for the file.
+ ID string `json:"id,required"`
+ // Size of the file in bytes.
+ Bytes int64 `json:"bytes,required"`
+ // The container this file belongs to.
+ ContainerID string `json:"container_id,required"`
+ // Unix timestamp (in seconds) when the file was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // The type of this object (`container.file`).
+ Object constant.ContainerFile `json:"object,required"`
+ // Path of the file in the container.
+ Path string `json:"path,required"`
+ // Source of the file (e.g., `user`, `assistant`).
+ Source string `json:"source,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ Bytes respjson.Field
+ ContainerID respjson.Field
+ CreatedAt respjson.Field
+ Object respjson.Field
+ Path respjson.Field
+ Source respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerFileNewResponse) RawJSON() string { return r.JSON.raw }
+func (r *ContainerFileNewResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type ContainerFileGetResponse struct {
+ // Unique identifier for the file.
+ ID string `json:"id,required"`
+ // Size of the file in bytes.
+ Bytes int64 `json:"bytes,required"`
+ // The container this file belongs to.
+ ContainerID string `json:"container_id,required"`
+ // Unix timestamp (in seconds) when the file was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // The type of this object (`container.file`).
+ Object constant.ContainerFile `json:"object,required"`
+ // Path of the file in the container.
+ Path string `json:"path,required"`
+ // Source of the file (e.g., `user`, `assistant`).
+ Source string `json:"source,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ Bytes respjson.Field
+ ContainerID respjson.Field
+ CreatedAt respjson.Field
+ Object respjson.Field
+ Path respjson.Field
+ Source respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerFileGetResponse) RawJSON() string { return r.JSON.raw }
+func (r *ContainerFileGetResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type ContainerFileListResponse struct {
+ // Unique identifier for the file.
+ ID string `json:"id,required"`
+ // Size of the file in bytes.
+ Bytes int64 `json:"bytes,required"`
+ // The container this file belongs to.
+ ContainerID string `json:"container_id,required"`
+ // Unix timestamp (in seconds) when the file was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // The type of this object (`container.file`).
+ Object constant.ContainerFile `json:"object,required"`
+ // Path of the file in the container.
+ Path string `json:"path,required"`
+ // Source of the file (e.g., `user`, `assistant`).
+ Source string `json:"source,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ Bytes respjson.Field
+ ContainerID respjson.Field
+ CreatedAt respjson.Field
+ Object respjson.Field
+ Path respjson.Field
+ Source respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ContainerFileListResponse) RawJSON() string { return r.JSON.raw }
+func (r *ContainerFileListResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type ContainerFileNewParams struct {
+ // Name of the file to create.
+ FileID param.Opt[string] `json:"file_id,omitzero"`
+ // The File object (not file name) to be uploaded.
+ File io.Reader `json:"file,omitzero" format:"binary"`
+ paramObj
+}
+
+func (r ContainerFileNewParams) MarshalMultipart() (data []byte, contentType string, err error) {
+ buf := bytes.NewBuffer(nil)
+ writer := multipart.NewWriter(buf)
+ err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
+ if err != nil {
+ writer.Close()
+ return nil, "", err
+ }
+ err = writer.Close()
+ if err != nil {
+ return nil, "", err
+ }
+ return buf.Bytes(), writer.FormDataContentType(), nil
+}
+
+type ContainerFileListParams struct {
+ // A cursor for use in pagination. `after` is an object ID that defines your place
+ // in the list. For instance, if you make a list request and receive 100 objects,
+ // ending with obj_foo, your subsequent call can include after=obj_foo in order to
+ // fetch the next page of the list.
+ After param.Opt[string] `query:"after,omitzero" json:"-"`
+ // A limit on the number of objects to be returned. Limit can range between 1 and
+ // 100, and the default is 20.
+ Limit param.Opt[int64] `query:"limit,omitzero" json:"-"`
+ // Sort order by the `created_at` timestamp of the objects. `asc` for ascending
+ // order and `desc` for descending order.
+ //
+ // Any of "asc", "desc".
+ Order ContainerFileListParamsOrder `query:"order,omitzero" json:"-"`
+ paramObj
+}
+
+// URLQuery serializes [ContainerFileListParams]'s query parameters as
+// `url.Values`.
+func (r ContainerFileListParams) URLQuery() (v url.Values, err error) {
+ return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
+ ArrayFormat: apiquery.ArrayQueryFormatBrackets,
+ NestedFormat: apiquery.NestedQueryFormatBrackets,
+ })
+}
+
+// Sort order by the `created_at` timestamp of the objects. `asc` for ascending
+// order and `desc` for descending order.
+type ContainerFileListParamsOrder string
+
+const (
+ ContainerFileListParamsOrderAsc ContainerFileListParamsOrder = "asc"
+ ContainerFileListParamsOrderDesc ContainerFileListParamsOrder = "desc"
+)
@@ -0,0 +1,49 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+
+ "github.com/openai/openai-go/internal/requestconfig"
+ "github.com/openai/openai-go/option"
+)
+
+// ContainerFileContentService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewContainerFileContentService] method instead.
+type ContainerFileContentService struct {
+ Options []option.RequestOption
+}
+
+// NewContainerFileContentService generates a new service that applies the given
+// options to each request. These options are applied after the parent client's
+// options (if there is one), and before any request-specific options.
+func NewContainerFileContentService(opts ...option.RequestOption) (r ContainerFileContentService) {
+ r = ContainerFileContentService{}
+ r.Options = opts
+ return
+}
+
+// Retrieve Container File Content
+func (r *ContainerFileContentService) Get(ctx context.Context, containerID string, fileID string, opts ...option.RequestOption) (res *http.Response, err error) {
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithHeader("Accept", "application/binary")}, opts...)
+ if containerID == "" {
+ err = errors.New("missing required container_id parameter")
+ return
+ }
+ if fileID == "" {
+ err = errors.New("missing required file_id parameter")
+ return
+ }
+ path := fmt.Sprintf("containers/%s/files/%s/content", containerID, fileID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, nil, &res, opts...)
+ return
+}
@@ -10,7 +10,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -50,14 +50,13 @@ type CreateEmbeddingResponse struct {
Object constant.List `json:"object,required"`
// The usage information for the request.
Usage CreateEmbeddingResponseUsage `json:"usage,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Model resp.Field
- Object resp.Field
- Usage resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -74,12 +73,11 @@ type CreateEmbeddingResponseUsage struct {
PromptTokens int64 `json:"prompt_tokens,required"`
// The total number of tokens used by the request.
TotalTokens int64 `json:"total_tokens,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- PromptTokens resp.Field
- TotalTokens resp.Field
- ExtraFields map[string]resp.Field
+ PromptTokens respjson.Field
+ TotalTokens respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -100,13 +98,12 @@ type Embedding struct {
Index int64 `json:"index,required"`
// The object type, which is always "embedding".
Object constant.Embedding `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Embedding resp.Field
- Index resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ Embedding respjson.Field
+ Index respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -129,11 +126,12 @@ type EmbeddingNewParams struct {
// Input text to embed, encoded as a string or array of tokens. To embed multiple
// inputs in a single request, pass an array of strings or array of token arrays.
// The input must not exceed the max input tokens for the model (8192 tokens for
- // `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048
+ // all embedding models), cannot be an empty string, and any array must be 2048
// dimensions or less.
// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
- // for counting tokens. Some models may also impose a limit on total number of
- // tokens summed across inputs.
+ // for counting tokens. In addition to the per-input token limit, all embedding
+ // models enforce a maximum of 300,000 tokens summed across all inputs in a single
+ // request.
Input EmbeddingNewParamsInputUnion `json:"input,omitzero,required"`
// ID of the model to use. You can use the
// [List models](https://platform.openai.com/docs/api-reference/models/list) API to
@@ -156,14 +154,13 @@ type EmbeddingNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f EmbeddingNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r EmbeddingNewParams) MarshalJSON() (data []byte, err error) {
type shadow EmbeddingNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *EmbeddingNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -176,11 +173,11 @@ type EmbeddingNewParamsInputUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u EmbeddingNewParamsInputUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u EmbeddingNewParamsInputUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[EmbeddingNewParamsInputUnion](u.OfString, u.OfArrayOfStrings, u.OfArrayOfTokens, u.OfArrayOfTokenArrays)
+ return param.MarshalUnion(u, u.OfString, u.OfArrayOfStrings, u.OfArrayOfTokens, u.OfArrayOfTokenArrays)
+}
+func (u *EmbeddingNewParamsInputUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *EmbeddingNewParamsInputUnion) asAny() any {
@@ -19,7 +19,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -134,13 +134,12 @@ type FileDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object constant.File `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -184,19 +183,18 @@ type FileObject struct {
//
// Deprecated: deprecated
StatusDetails string `json:"status_details"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Bytes resp.Field
- CreatedAt resp.Field
- Filename resp.Field
- Object resp.Field
- Purpose resp.Field
- Status resp.Field
- ExpiresAt resp.Field
- StatusDetails resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Bytes respjson.Field
+ CreatedAt respjson.Field
+ Filename respjson.Field
+ Object respjson.Field
+ Purpose respjson.Field
+ Status respjson.Field
+ ExpiresAt respjson.Field
+ StatusDetails respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -249,7 +247,7 @@ const (
type FileNewParams struct {
// The File object (not file name) to be uploaded.
- File io.Reader `json:"file,required" format:"binary"`
+ File io.Reader `json:"file,omitzero,required" format:"binary"`
// The intended purpose of the uploaded file. One of: - `assistants`: Used in the
// Assistants API - `batch`: Used in the Batch API - `fine-tune`: Used for
// fine-tuning - `vision`: Images used for vision fine-tuning - `user_data`:
@@ -260,14 +258,13 @@ type FileNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FileNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r FileNewParams) MarshalMultipart() (data []byte, contentType string, err error) {
buf := bytes.NewBuffer(nil)
writer := multipart.NewWriter(buf)
err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
if err != nil {
writer.Close()
return nil, "", err
@@ -298,12 +295,8 @@ type FileListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FileListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [FileListParams]'s query parameters as `url.Values`.
-func (r FileListParams) URLQuery() (v url.Values) {
+func (r FileListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -13,8 +13,11 @@ import (
// automatically. You should not instantiate this service directly, and instead use
// the [NewFineTuningService] method instead.
type FineTuningService struct {
- Options []option.RequestOption
- Jobs FineTuningJobService
+ Options []option.RequestOption
+ Methods FineTuningMethodService
+ Jobs FineTuningJobService
+ Checkpoints FineTuningCheckpointService
+ Alpha FineTuningAlphaService
}
// NewFineTuningService generates a new service that applies the given options to
@@ -23,6 +26,9 @@ type FineTuningService struct {
func NewFineTuningService(opts ...option.RequestOption) (r FineTuningService) {
r = FineTuningService{}
r.Options = opts
+ r.Methods = NewFineTuningMethodService(opts...)
r.Jobs = NewFineTuningJobService(opts...)
+ r.Checkpoints = NewFineTuningCheckpointService(opts...)
+ r.Alpha = NewFineTuningAlphaService(opts...)
return
}
@@ -0,0 +1,28 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "github.com/openai/openai-go/option"
+)
+
+// FineTuningAlphaService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewFineTuningAlphaService] method instead.
+type FineTuningAlphaService struct {
+ Options []option.RequestOption
+ Graders FineTuningAlphaGraderService
+}
+
+// NewFineTuningAlphaService generates a new service that applies the given options
+// to each request. These options are applied after the parent client's options (if
+// there is one), and before any request-specific options.
+func NewFineTuningAlphaService(opts ...option.RequestOption) (r FineTuningAlphaService) {
+ r = FineTuningAlphaService{}
+ r.Options = opts
+ r.Graders = NewFineTuningAlphaGraderService(opts...)
+ return
+}
@@ -0,0 +1,672 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/internal/requestconfig"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+)
+
+// FineTuningAlphaGraderService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewFineTuningAlphaGraderService] method instead.
+type FineTuningAlphaGraderService struct {
+ Options []option.RequestOption
+}
+
+// NewFineTuningAlphaGraderService generates a new service that applies the given
+// options to each request. These options are applied after the parent client's
+// options (if there is one), and before any request-specific options.
+func NewFineTuningAlphaGraderService(opts ...option.RequestOption) (r FineTuningAlphaGraderService) {
+ r = FineTuningAlphaGraderService{}
+ r.Options = opts
+ return
+}
+
+// Run a grader.
+func (r *FineTuningAlphaGraderService) Run(ctx context.Context, body FineTuningAlphaGraderRunParams, opts ...option.RequestOption) (res *FineTuningAlphaGraderRunResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ path := "fine_tuning/alpha/graders/run"
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...)
+ return
+}
+
+// Validate a grader.
+func (r *FineTuningAlphaGraderService) Validate(ctx context.Context, body FineTuningAlphaGraderValidateParams, opts ...option.RequestOption) (res *FineTuningAlphaGraderValidateResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ path := "fine_tuning/alpha/graders/validate"
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...)
+ return
+}
+
+type FineTuningAlphaGraderRunResponse struct {
+ Metadata FineTuningAlphaGraderRunResponseMetadata `json:"metadata,required"`
+ ModelGraderTokenUsagePerModel map[string]any `json:"model_grader_token_usage_per_model,required"`
+ Reward float64 `json:"reward,required"`
+ SubRewards map[string]any `json:"sub_rewards,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Metadata respjson.Field
+ ModelGraderTokenUsagePerModel respjson.Field
+ Reward respjson.Field
+ SubRewards respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningAlphaGraderRunResponse) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningAlphaGraderRunResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningAlphaGraderRunResponseMetadata struct {
+ Errors FineTuningAlphaGraderRunResponseMetadataErrors `json:"errors,required"`
+ ExecutionTime float64 `json:"execution_time,required"`
+ Name string `json:"name,required"`
+ SampledModelName string `json:"sampled_model_name,required"`
+ Scores map[string]any `json:"scores,required"`
+ TokenUsage int64 `json:"token_usage,required"`
+ Type string `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Errors respjson.Field
+ ExecutionTime respjson.Field
+ Name respjson.Field
+ SampledModelName respjson.Field
+ Scores respjson.Field
+ TokenUsage respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningAlphaGraderRunResponseMetadata) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningAlphaGraderRunResponseMetadata) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningAlphaGraderRunResponseMetadataErrors struct {
+ FormulaParseError bool `json:"formula_parse_error,required"`
+ InvalidVariableError bool `json:"invalid_variable_error,required"`
+ ModelGraderParseError bool `json:"model_grader_parse_error,required"`
+ ModelGraderRefusalError bool `json:"model_grader_refusal_error,required"`
+ ModelGraderServerError bool `json:"model_grader_server_error,required"`
+ ModelGraderServerErrorDetails string `json:"model_grader_server_error_details,required"`
+ OtherError bool `json:"other_error,required"`
+ PythonGraderRuntimeError bool `json:"python_grader_runtime_error,required"`
+ PythonGraderRuntimeErrorDetails string `json:"python_grader_runtime_error_details,required"`
+ PythonGraderServerError bool `json:"python_grader_server_error,required"`
+ PythonGraderServerErrorType string `json:"python_grader_server_error_type,required"`
+ SampleParseError bool `json:"sample_parse_error,required"`
+ TruncatedObservationError bool `json:"truncated_observation_error,required"`
+ UnresponsiveRewardError bool `json:"unresponsive_reward_error,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ FormulaParseError respjson.Field
+ InvalidVariableError respjson.Field
+ ModelGraderParseError respjson.Field
+ ModelGraderRefusalError respjson.Field
+ ModelGraderServerError respjson.Field
+ ModelGraderServerErrorDetails respjson.Field
+ OtherError respjson.Field
+ PythonGraderRuntimeError respjson.Field
+ PythonGraderRuntimeErrorDetails respjson.Field
+ PythonGraderServerError respjson.Field
+ PythonGraderServerErrorType respjson.Field
+ SampleParseError respjson.Field
+ TruncatedObservationError respjson.Field
+ UnresponsiveRewardError respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningAlphaGraderRunResponseMetadataErrors) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningAlphaGraderRunResponseMetadataErrors) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningAlphaGraderValidateResponse struct {
+ // The grader used for the fine-tuning job.
+ Grader FineTuningAlphaGraderValidateResponseGraderUnion `json:"grader"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Grader respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningAlphaGraderValidateResponse) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningAlphaGraderValidateResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// FineTuningAlphaGraderValidateResponseGraderUnion contains all possible
+// properties and values from [StringCheckGrader], [TextSimilarityGrader],
+// [PythonGrader], [ScoreModelGrader], [MultiGrader].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+type FineTuningAlphaGraderValidateResponseGraderUnion struct {
+ // This field is a union of [string], [string], [[]ScoreModelGraderInput]
+ Input FineTuningAlphaGraderValidateResponseGraderUnionInput `json:"input"`
+ Name string `json:"name"`
+ // This field is from variant [StringCheckGrader].
+ Operation StringCheckGraderOperation `json:"operation"`
+ Reference string `json:"reference"`
+ Type string `json:"type"`
+ // This field is from variant [TextSimilarityGrader].
+ EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric"`
+ // This field is from variant [PythonGrader].
+ Source string `json:"source"`
+ // This field is from variant [PythonGrader].
+ ImageTag string `json:"image_tag"`
+ // This field is from variant [ScoreModelGrader].
+ Model string `json:"model"`
+ // This field is from variant [ScoreModelGrader].
+ Range []float64 `json:"range"`
+ // This field is from variant [ScoreModelGrader].
+ SamplingParams any `json:"sampling_params"`
+ // This field is from variant [MultiGrader].
+ CalculateOutput string `json:"calculate_output"`
+ // This field is from variant [MultiGrader].
+ Graders MultiGraderGradersUnion `json:"graders"`
+ JSON struct {
+ Input respjson.Field
+ Name respjson.Field
+ Operation respjson.Field
+ Reference respjson.Field
+ Type respjson.Field
+ EvaluationMetric respjson.Field
+ Source respjson.Field
+ ImageTag respjson.Field
+ Model respjson.Field
+ Range respjson.Field
+ SamplingParams respjson.Field
+ CalculateOutput respjson.Field
+ Graders respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsStringCheckGrader() (v StringCheckGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsTextSimilarityGrader() (v TextSimilarityGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsPythonGrader() (v PythonGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsScoreModelGrader() (v ScoreModelGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u FineTuningAlphaGraderValidateResponseGraderUnion) AsMultiGrader() (v MultiGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u FineTuningAlphaGraderValidateResponseGraderUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *FineTuningAlphaGraderValidateResponseGraderUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// FineTuningAlphaGraderValidateResponseGraderUnionInput is an implicit subunion of
+// [FineTuningAlphaGraderValidateResponseGraderUnion].
+// FineTuningAlphaGraderValidateResponseGraderUnionInput provides convenient access
+// to the sub-properties of the union.
+//
+// For type safety it is recommended to directly use a variant of the
+// [FineTuningAlphaGraderValidateResponseGraderUnion].
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString OfScoreModelGraderInputArray]
+type FineTuningAlphaGraderValidateResponseGraderUnionInput struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ // This field will be present if the value is a [[]ScoreModelGraderInput] instead
+ // of an object.
+ OfScoreModelGraderInputArray []ScoreModelGraderInput `json:",inline"`
+ JSON struct {
+ OfString respjson.Field
+ OfScoreModelGraderInputArray respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (r *FineTuningAlphaGraderValidateResponseGraderUnionInput) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningAlphaGraderRunParams struct {
+ // The grader used for the fine-tuning job.
+ Grader FineTuningAlphaGraderRunParamsGraderUnion `json:"grader,omitzero,required"`
+ // The model sample to be evaluated. This value will be used to populate the
+ // `sample` namespace. See
+ // [the guide](https://platform.openai.com/docs/guides/graders) for more details.
+ // The `output_json` variable will be populated if the model sample is a valid JSON
+ // string.
+ ModelSample string `json:"model_sample,required"`
+ // The dataset item provided to the grader. This will be used to populate the
+ // `item` namespace. See
+ // [the guide](https://platform.openai.com/docs/guides/graders) for more details.
+ Item any `json:"item,omitzero"`
+ paramObj
+}
+
+func (r FineTuningAlphaGraderRunParams) MarshalJSON() (data []byte, err error) {
+ type shadow FineTuningAlphaGraderRunParams
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *FineTuningAlphaGraderRunParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type FineTuningAlphaGraderRunParamsGraderUnion struct {
+ OfStringCheck *StringCheckGraderParam `json:",omitzero,inline"`
+ OfTextSimilarity *TextSimilarityGraderParam `json:",omitzero,inline"`
+ OfPython *PythonGraderParam `json:",omitzero,inline"`
+ OfScoreModel *ScoreModelGraderParam `json:",omitzero,inline"`
+ OfMulti *MultiGraderParam `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u FineTuningAlphaGraderRunParamsGraderUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfStringCheck,
+ u.OfTextSimilarity,
+ u.OfPython,
+ u.OfScoreModel,
+ u.OfMulti)
+}
+func (u *FineTuningAlphaGraderRunParamsGraderUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *FineTuningAlphaGraderRunParamsGraderUnion) asAny() any {
+ if !param.IsOmitted(u.OfStringCheck) {
+ return u.OfStringCheck
+ } else if !param.IsOmitted(u.OfTextSimilarity) {
+ return u.OfTextSimilarity
+ } else if !param.IsOmitted(u.OfPython) {
+ return u.OfPython
+ } else if !param.IsOmitted(u.OfScoreModel) {
+ return u.OfScoreModel
+ } else if !param.IsOmitted(u.OfMulti) {
+ return u.OfMulti
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetOperation() *string {
+ if vt := u.OfStringCheck; vt != nil {
+ return (*string)(&vt.Operation)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetEvaluationMetric() *string {
+ if vt := u.OfTextSimilarity; vt != nil {
+ return (*string)(&vt.EvaluationMetric)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetSource() *string {
+ if vt := u.OfPython; vt != nil {
+ return &vt.Source
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetImageTag() *string {
+ if vt := u.OfPython; vt != nil && vt.ImageTag.Valid() {
+ return &vt.ImageTag.Value
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetModel() *string {
+ if vt := u.OfScoreModel; vt != nil {
+ return &vt.Model
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetRange() []float64 {
+ if vt := u.OfScoreModel; vt != nil {
+ return vt.Range
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetSamplingParams() *any {
+ if vt := u.OfScoreModel; vt != nil {
+ return &vt.SamplingParams
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetCalculateOutput() *string {
+ if vt := u.OfMulti; vt != nil {
+ return &vt.CalculateOutput
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetGraders() *MultiGraderGradersUnionParam {
+ if vt := u.OfMulti; vt != nil {
+ return &vt.Graders
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetName() *string {
+ if vt := u.OfStringCheck; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfTextSimilarity; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfPython; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfScoreModel; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfMulti; vt != nil {
+ return (*string)(&vt.Name)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetReference() *string {
+ if vt := u.OfStringCheck; vt != nil {
+ return (*string)(&vt.Reference)
+ } else if vt := u.OfTextSimilarity; vt != nil {
+ return (*string)(&vt.Reference)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetType() *string {
+ if vt := u.OfStringCheck; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfTextSimilarity; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfPython; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfScoreModel; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfMulti; vt != nil {
+ return (*string)(&vt.Type)
+ }
+ return nil
+}
+
+// Returns a subunion which exports methods to access subproperties
+//
+// Or use AsAny() to get the underlying value
+func (u FineTuningAlphaGraderRunParamsGraderUnion) GetInput() (res fineTuningAlphaGraderRunParamsGraderUnionInput) {
+ if vt := u.OfStringCheck; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfTextSimilarity; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfScoreModel; vt != nil {
+ res.any = &vt.Input
+ }
+ return
+}
+
+// Can have the runtime types [*string], [\*[]ScoreModelGraderInputParam]
+type fineTuningAlphaGraderRunParamsGraderUnionInput struct{ any }
+
+// Use the following switch statement to get the type of the union:
+//
+// switch u.AsAny().(type) {
+// case *string:
+// case *[]openai.ScoreModelGraderInputParam:
+// default:
+// fmt.Errorf("not present")
+// }
+func (u fineTuningAlphaGraderRunParamsGraderUnionInput) AsAny() any { return u.any }
+
+func init() {
+ apijson.RegisterUnion[FineTuningAlphaGraderRunParamsGraderUnion](
+ "type",
+ apijson.Discriminator[StringCheckGraderParam]("string_check"),
+ apijson.Discriminator[TextSimilarityGraderParam]("text_similarity"),
+ apijson.Discriminator[PythonGraderParam]("python"),
+ apijson.Discriminator[ScoreModelGraderParam]("score_model"),
+ apijson.Discriminator[MultiGraderParam]("multi"),
+ )
+}
+
+type FineTuningAlphaGraderValidateParams struct {
+ // The grader used for the fine-tuning job.
+ Grader FineTuningAlphaGraderValidateParamsGraderUnion `json:"grader,omitzero,required"`
+ paramObj
+}
+
+func (r FineTuningAlphaGraderValidateParams) MarshalJSON() (data []byte, err error) {
+ type shadow FineTuningAlphaGraderValidateParams
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *FineTuningAlphaGraderValidateParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type FineTuningAlphaGraderValidateParamsGraderUnion struct {
+ OfStringCheckGrader *StringCheckGraderParam `json:",omitzero,inline"`
+ OfTextSimilarityGrader *TextSimilarityGraderParam `json:",omitzero,inline"`
+ OfPythonGrader *PythonGraderParam `json:",omitzero,inline"`
+ OfScoreModelGrader *ScoreModelGraderParam `json:",omitzero,inline"`
+ OfMultiGrader *MultiGraderParam `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfStringCheckGrader,
+ u.OfTextSimilarityGrader,
+ u.OfPythonGrader,
+ u.OfScoreModelGrader,
+ u.OfMultiGrader)
+}
+func (u *FineTuningAlphaGraderValidateParamsGraderUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *FineTuningAlphaGraderValidateParamsGraderUnion) asAny() any {
+ if !param.IsOmitted(u.OfStringCheckGrader) {
+ return u.OfStringCheckGrader
+ } else if !param.IsOmitted(u.OfTextSimilarityGrader) {
+ return u.OfTextSimilarityGrader
+ } else if !param.IsOmitted(u.OfPythonGrader) {
+ return u.OfPythonGrader
+ } else if !param.IsOmitted(u.OfScoreModelGrader) {
+ return u.OfScoreModelGrader
+ } else if !param.IsOmitted(u.OfMultiGrader) {
+ return u.OfMultiGrader
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetOperation() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Operation)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetEvaluationMetric() *string {
+ if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.EvaluationMetric)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetSource() *string {
+ if vt := u.OfPythonGrader; vt != nil {
+ return &vt.Source
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetImageTag() *string {
+ if vt := u.OfPythonGrader; vt != nil && vt.ImageTag.Valid() {
+ return &vt.ImageTag.Value
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetModel() *string {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return &vt.Model
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetRange() []float64 {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return vt.Range
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetSamplingParams() *any {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return &vt.SamplingParams
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetCalculateOutput() *string {
+ if vt := u.OfMultiGrader; vt != nil {
+ return &vt.CalculateOutput
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetGraders() *MultiGraderGradersUnionParam {
+ if vt := u.OfMultiGrader; vt != nil {
+ return &vt.Graders
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetName() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfPythonGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfMultiGrader; vt != nil {
+ return (*string)(&vt.Name)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetReference() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Reference)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Reference)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetType() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfPythonGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfMultiGrader; vt != nil {
+ return (*string)(&vt.Type)
+ }
+ return nil
+}
+
+// Returns a subunion which exports methods to access subproperties
+//
+// Or use AsAny() to get the underlying value
+func (u FineTuningAlphaGraderValidateParamsGraderUnion) GetInput() (res fineTuningAlphaGraderValidateParamsGraderUnionInput) {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ res.any = &vt.Input
+ }
+ return
+}
+
+// Can have the runtime types [*string], [\*[]ScoreModelGraderInputParam]
+type fineTuningAlphaGraderValidateParamsGraderUnionInput struct{ any }
+
+// Use the following switch statement to get the type of the union:
+//
+// switch u.AsAny().(type) {
+// case *string:
+// case *[]openai.ScoreModelGraderInputParam:
+// default:
+// fmt.Errorf("not present")
+// }
+func (u fineTuningAlphaGraderValidateParamsGraderUnionInput) AsAny() any { return u.any }
@@ -0,0 +1,28 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "github.com/openai/openai-go/option"
+)
+
+// FineTuningCheckpointService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewFineTuningCheckpointService] method instead.
+type FineTuningCheckpointService struct {
+ Options []option.RequestOption
+ Permissions FineTuningCheckpointPermissionService
+}
+
+// NewFineTuningCheckpointService generates a new service that applies the given
+// options to each request. These options are applied after the parent client's
+// options (if there is one), and before any request-specific options.
+func NewFineTuningCheckpointService(opts ...option.RequestOption) (r FineTuningCheckpointService) {
+ r = FineTuningCheckpointService{}
+ r.Options = opts
+ r.Permissions = NewFineTuningCheckpointPermissionService(opts...)
+ return
+}
@@ -0,0 +1,254 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "net/url"
+
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/internal/apiquery"
+ "github.com/openai/openai-go/internal/requestconfig"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/pagination"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+ "github.com/openai/openai-go/shared/constant"
+)
+
+// FineTuningCheckpointPermissionService contains methods and other services that
+// help with interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewFineTuningCheckpointPermissionService] method instead.
+type FineTuningCheckpointPermissionService struct {
+ Options []option.RequestOption
+}
+
+// NewFineTuningCheckpointPermissionService generates a new service that applies
+// the given options to each request. These options are applied after the parent
+// client's options (if there is one), and before any request-specific options.
+func NewFineTuningCheckpointPermissionService(opts ...option.RequestOption) (r FineTuningCheckpointPermissionService) {
+ r = FineTuningCheckpointPermissionService{}
+ r.Options = opts
+ return
+}
+
+// **NOTE:** Calling this endpoint requires an [admin API key](../admin-api-keys).
+//
+// This enables organization owners to share fine-tuned models with other projects
+// in their organization.
+func (r *FineTuningCheckpointPermissionService) New(ctx context.Context, fineTunedModelCheckpoint string, body FineTuningCheckpointPermissionNewParams, opts ...option.RequestOption) (res *pagination.Page[FineTuningCheckpointPermissionNewResponse], err error) {
+ var raw *http.Response
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithResponseInto(&raw)}, opts...)
+ if fineTunedModelCheckpoint == "" {
+ err = errors.New("missing required fine_tuned_model_checkpoint parameter")
+ return
+ }
+ path := fmt.Sprintf("fine_tuning/checkpoints/%s/permissions", fineTunedModelCheckpoint)
+ cfg, err := requestconfig.NewRequestConfig(ctx, http.MethodPost, path, body, &res, opts...)
+ if err != nil {
+ return nil, err
+ }
+ err = cfg.Execute()
+ if err != nil {
+ return nil, err
+ }
+ res.SetPageConfig(cfg, raw)
+ return res, nil
+}
+
+// **NOTE:** Calling this endpoint requires an [admin API key](../admin-api-keys).
+//
+// This enables organization owners to share fine-tuned models with other projects
+// in their organization.
+func (r *FineTuningCheckpointPermissionService) NewAutoPaging(ctx context.Context, fineTunedModelCheckpoint string, body FineTuningCheckpointPermissionNewParams, opts ...option.RequestOption) *pagination.PageAutoPager[FineTuningCheckpointPermissionNewResponse] {
+ return pagination.NewPageAutoPager(r.New(ctx, fineTunedModelCheckpoint, body, opts...))
+}
+
+// **NOTE:** This endpoint requires an [admin API key](../admin-api-keys).
+//
+// Organization owners can use this endpoint to view all permissions for a
+// fine-tuned model checkpoint.
+func (r *FineTuningCheckpointPermissionService) Get(ctx context.Context, fineTunedModelCheckpoint string, query FineTuningCheckpointPermissionGetParams, opts ...option.RequestOption) (res *FineTuningCheckpointPermissionGetResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ if fineTunedModelCheckpoint == "" {
+ err = errors.New("missing required fine_tuned_model_checkpoint parameter")
+ return
+ }
+ path := fmt.Sprintf("fine_tuning/checkpoints/%s/permissions", fineTunedModelCheckpoint)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...)
+ return
+}
+
+// **NOTE:** This endpoint requires an [admin API key](../admin-api-keys).
+//
+// Organization owners can use this endpoint to delete a permission for a
+// fine-tuned model checkpoint.
+func (r *FineTuningCheckpointPermissionService) Delete(ctx context.Context, fineTunedModelCheckpoint string, permissionID string, opts ...option.RequestOption) (res *FineTuningCheckpointPermissionDeleteResponse, err error) {
+ opts = append(r.Options[:], opts...)
+ if fineTunedModelCheckpoint == "" {
+ err = errors.New("missing required fine_tuned_model_checkpoint parameter")
+ return
+ }
+ if permissionID == "" {
+ err = errors.New("missing required permission_id parameter")
+ return
+ }
+ path := fmt.Sprintf("fine_tuning/checkpoints/%s/permissions/%s", fineTunedModelCheckpoint, permissionID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, &res, opts...)
+ return
+}
+
+// The `checkpoint.permission` object represents a permission for a fine-tuned
+// model checkpoint.
+type FineTuningCheckpointPermissionNewResponse struct {
+ // The permission identifier, which can be referenced in the API endpoints.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) for when the permission was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // The object type, which is always "checkpoint.permission".
+ Object constant.CheckpointPermission `json:"object,required"`
+ // The project identifier that the permission is for.
+ ProjectID string `json:"project_id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Object respjson.Field
+ ProjectID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningCheckpointPermissionNewResponse) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningCheckpointPermissionNewResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningCheckpointPermissionGetResponse struct {
+ Data []FineTuningCheckpointPermissionGetResponseData `json:"data,required"`
+ HasMore bool `json:"has_more,required"`
+ Object constant.List `json:"object,required"`
+ FirstID string `json:"first_id,nullable"`
+ LastID string `json:"last_id,nullable"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Data respjson.Field
+ HasMore respjson.Field
+ Object respjson.Field
+ FirstID respjson.Field
+ LastID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningCheckpointPermissionGetResponse) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningCheckpointPermissionGetResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The `checkpoint.permission` object represents a permission for a fine-tuned
+// model checkpoint.
+type FineTuningCheckpointPermissionGetResponseData struct {
+ // The permission identifier, which can be referenced in the API endpoints.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) for when the permission was created.
+ CreatedAt int64 `json:"created_at,required"`
+ // The object type, which is always "checkpoint.permission".
+ Object constant.CheckpointPermission `json:"object,required"`
+ // The project identifier that the permission is for.
+ ProjectID string `json:"project_id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Object respjson.Field
+ ProjectID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningCheckpointPermissionGetResponseData) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningCheckpointPermissionGetResponseData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningCheckpointPermissionDeleteResponse struct {
+ // The ID of the fine-tuned model checkpoint permission that was deleted.
+ ID string `json:"id,required"`
+ // Whether the fine-tuned model checkpoint permission was successfully deleted.
+ Deleted bool `json:"deleted,required"`
+ // The object type, which is always "checkpoint.permission".
+ Object constant.CheckpointPermission `json:"object,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningCheckpointPermissionDeleteResponse) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningCheckpointPermissionDeleteResponse) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningCheckpointPermissionNewParams struct {
+ // The project identifiers to grant access to.
+ ProjectIDs []string `json:"project_ids,omitzero,required"`
+ paramObj
+}
+
+func (r FineTuningCheckpointPermissionNewParams) MarshalJSON() (data []byte, err error) {
+ type shadow FineTuningCheckpointPermissionNewParams
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *FineTuningCheckpointPermissionNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+type FineTuningCheckpointPermissionGetParams struct {
+ // Identifier for the last permission ID from the previous pagination request.
+ After param.Opt[string] `query:"after,omitzero" json:"-"`
+ // Number of permissions to retrieve.
+ Limit param.Opt[int64] `query:"limit,omitzero" json:"-"`
+ // The ID of the project to get permissions for.
+ ProjectID param.Opt[string] `query:"project_id,omitzero" json:"-"`
+ // The order in which to retrieve permissions.
+ //
+ // Any of "ascending", "descending".
+ Order FineTuningCheckpointPermissionGetParamsOrder `query:"order,omitzero" json:"-"`
+ paramObj
+}
+
+// URLQuery serializes [FineTuningCheckpointPermissionGetParams]'s query parameters
+// as `url.Values`.
+func (r FineTuningCheckpointPermissionGetParams) URLQuery() (v url.Values, err error) {
+ return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
+ ArrayFormat: apiquery.ArrayQueryFormatBrackets,
+ NestedFormat: apiquery.NestedQueryFormatBrackets,
+ })
+}
+
+// The order in which to retrieve permissions.
+type FineTuningCheckpointPermissionGetParamsOrder string
+
+const (
+ FineTuningCheckpointPermissionGetParamsOrderAscending FineTuningCheckpointPermissionGetParamsOrder = "ascending"
+ FineTuningCheckpointPermissionGetParamsOrderDescending FineTuningCheckpointPermissionGetParamsOrder = "descending"
+)
@@ -16,7 +16,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
)
@@ -48,7 +48,7 @@ func NewFineTuningJobService(opts ...option.RequestOption) (r FineTuningJobServi
// Response includes details of the enqueued job including job status and the name
// of the fine-tuned models once complete.
//
-// [Learn more about fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
+// [Learn more about fine-tuning](https://platform.openai.com/docs/guides/model-optimization)
func (r *FineTuningJobService) New(ctx context.Context, body FineTuningJobNewParams, opts ...option.RequestOption) (res *FineTuningJob, err error) {
opts = append(r.Options[:], opts...)
path := "fine_tuning/jobs"
@@ -58,7 +58,7 @@ func (r *FineTuningJobService) New(ctx context.Context, body FineTuningJobNewPar
// Get info about a fine-tuning job.
//
-// [Learn more about fine-tuning](https://platform.openai.com/docs/guides/fine-tuning)
+// [Learn more about fine-tuning](https://platform.openai.com/docs/guides/model-optimization)
func (r *FineTuningJobService) Get(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) {
opts = append(r.Options[:], opts...)
if fineTuningJobID == "" {
@@ -132,6 +132,30 @@ func (r *FineTuningJobService) ListEventsAutoPaging(ctx context.Context, fineTun
return pagination.NewCursorPageAutoPager(r.ListEvents(ctx, fineTuningJobID, query, opts...))
}
+// Pause a fine-tune job.
+func (r *FineTuningJobService) Pause(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) {
+ opts = append(r.Options[:], opts...)
+ if fineTuningJobID == "" {
+ err = errors.New("missing required fine_tuning_job_id parameter")
+ return
+ }
+ path := fmt.Sprintf("fine_tuning/jobs/%s/pause", fineTuningJobID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...)
+ return
+}
+
+// Resume a fine-tune job.
+func (r *FineTuningJobService) Resume(ctx context.Context, fineTuningJobID string, opts ...option.RequestOption) (res *FineTuningJob, err error) {
+ opts = append(r.Options[:], opts...)
+ if fineTuningJobID == "" {
+ err = errors.New("missing required fine_tuning_job_id parameter")
+ return
+ }
+ path := fmt.Sprintf("fine_tuning/jobs/%s/resume", fineTuningJobID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...)
+ return
+}
+
// The `fine_tuning.job` object represents a fine-tuning job that has been created
// through the API.
type FineTuningJob struct {
@@ -193,29 +217,28 @@ type FineTuningJob struct {
Metadata shared.Metadata `json:"metadata,nullable"`
// The method used for fine-tuning.
Method FineTuningJobMethod `json:"method"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Error resp.Field
- FineTunedModel resp.Field
- FinishedAt resp.Field
- Hyperparameters resp.Field
- Model resp.Field
- Object resp.Field
- OrganizationID resp.Field
- ResultFiles resp.Field
- Seed resp.Field
- Status resp.Field
- TrainedTokens resp.Field
- TrainingFile resp.Field
- ValidationFile resp.Field
- EstimatedFinish resp.Field
- Integrations resp.Field
- Metadata resp.Field
- Method resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Error respjson.Field
+ FineTunedModel respjson.Field
+ FinishedAt respjson.Field
+ Hyperparameters respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ OrganizationID respjson.Field
+ ResultFiles respjson.Field
+ Seed respjson.Field
+ Status respjson.Field
+ TrainedTokens respjson.Field
+ TrainingFile respjson.Field
+ ValidationFile respjson.Field
+ EstimatedFinish respjson.Field
+ Integrations respjson.Field
+ Metadata respjson.Field
+ Method respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -236,13 +259,12 @@ type FineTuningJobError struct {
// The parameter that was invalid, usually `training_file` or `validation_file`.
// This field will be null if the failure was not parameter-specific.
Param string `json:"param,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Message resp.Field
- Param resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Message respjson.Field
+ Param respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -258,20 +280,19 @@ func (r *FineTuningJobError) UnmarshalJSON(data []byte) error {
type FineTuningJobHyperparameters struct {
// Number of examples in each batch. A larger batch size means that model
// parameters are updated less frequently, but with lower variance.
- BatchSize FineTuningJobHyperparametersBatchSizeUnion `json:"batch_size"`
+ BatchSize FineTuningJobHyperparametersBatchSizeUnion `json:"batch_size,nullable"`
// Scaling factor for the learning rate. A smaller learning rate may be useful to
// avoid overfitting.
LearningRateMultiplier FineTuningJobHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier"`
// The number of epochs to train the model for. An epoch refers to one full cycle
// through the training dataset.
NEpochs FineTuningJobHyperparametersNEpochsUnion `json:"n_epochs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- BatchSize resp.Field
- LearningRateMultiplier resp.Field
- NEpochs resp.Field
- ExtraFields map[string]resp.Field
+ BatchSize respjson.Field
+ LearningRateMultiplier respjson.Field
+ NEpochs respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -296,8 +317,8 @@ type FineTuningJobHyperparametersBatchSizeUnion struct {
// This field will be present if the value is a [int64] instead of an object.
OfInt int64 `json:",inline"`
JSON struct {
- OfAuto resp.Field
- OfInt resp.Field
+ OfAuto respjson.Field
+ OfInt respjson.Field
raw string
} `json:"-"`
}
@@ -333,8 +354,8 @@ type FineTuningJobHyperparametersLearningRateMultiplierUnion struct {
// This field will be present if the value is a [float64] instead of an object.
OfFloat float64 `json:",inline"`
JSON struct {
- OfAuto resp.Field
- OfFloat resp.Field
+ OfAuto respjson.Field
+ OfFloat respjson.Field
raw string
} `json:"-"`
}
@@ -370,8 +391,8 @@ type FineTuningJobHyperparametersNEpochsUnion struct {
// This field will be present if the value is a [int64] instead of an object.
OfInt int64 `json:",inline"`
JSON struct {
- OfAuto resp.Field
- OfInt resp.Field
+ OfAuto respjson.Field
+ OfInt respjson.Field
raw string
} `json:"-"`
}
@@ -408,22 +429,24 @@ const (
// The method used for fine-tuning.
type FineTuningJobMethod struct {
+ // The type of method. Is either `supervised`, `dpo`, or `reinforcement`.
+ //
+ // Any of "supervised", "dpo", "reinforcement".
+ Type string `json:"type,required"`
// Configuration for the DPO fine-tuning method.
- Dpo FineTuningJobMethodDpo `json:"dpo"`
+ Dpo DpoMethod `json:"dpo"`
+ // Configuration for the reinforcement fine-tuning method.
+ Reinforcement ReinforcementMethod `json:"reinforcement"`
// Configuration for the supervised fine-tuning method.
- Supervised FineTuningJobMethodSupervised `json:"supervised"`
- // The type of method. Is either `supervised` or `dpo`.
- //
- // Any of "supervised", "dpo".
- Type string `json:"type"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ Supervised SupervisedMethod `json:"supervised"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Dpo resp.Field
- Supervised resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ Type respjson.Field
+ Dpo respjson.Field
+ Reinforcement respjson.Field
+ Supervised respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -433,369 +456,6 @@ func (r *FineTuningJobMethod) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
-// Configuration for the DPO fine-tuning method.
-type FineTuningJobMethodDpo struct {
- // The hyperparameters used for the fine-tuning job.
- Hyperparameters FineTuningJobMethodDpoHyperparameters `json:"hyperparameters"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- Hyperparameters resp.Field
- ExtraFields map[string]resp.Field
- raw string
- } `json:"-"`
-}
-
-// Returns the unmodified JSON received from the API
-func (r FineTuningJobMethodDpo) RawJSON() string { return r.JSON.raw }
-func (r *FineTuningJobMethodDpo) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// The hyperparameters used for the fine-tuning job.
-type FineTuningJobMethodDpoHyperparameters struct {
- // Number of examples in each batch. A larger batch size means that model
- // parameters are updated less frequently, but with lower variance.
- BatchSize FineTuningJobMethodDpoHyperparametersBatchSizeUnion `json:"batch_size"`
- // The beta value for the DPO method. A higher beta value will increase the weight
- // of the penalty between the policy and reference model.
- Beta FineTuningJobMethodDpoHyperparametersBetaUnion `json:"beta"`
- // Scaling factor for the learning rate. A smaller learning rate may be useful to
- // avoid overfitting.
- LearningRateMultiplier FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier"`
- // The number of epochs to train the model for. An epoch refers to one full cycle
- // through the training dataset.
- NEpochs FineTuningJobMethodDpoHyperparametersNEpochsUnion `json:"n_epochs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- BatchSize resp.Field
- Beta resp.Field
- LearningRateMultiplier resp.Field
- NEpochs resp.Field
- ExtraFields map[string]resp.Field
- raw string
- } `json:"-"`
-}
-
-// Returns the unmodified JSON received from the API
-func (r FineTuningJobMethodDpoHyperparameters) RawJSON() string { return r.JSON.raw }
-func (r *FineTuningJobMethodDpoHyperparameters) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodDpoHyperparametersBatchSizeUnion contains all possible
-// properties and values from [constant.Auto], [int64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfInt]
-type FineTuningJobMethodDpoHyperparametersBatchSizeUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [int64] instead of an object.
- OfInt int64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfInt resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodDpoHyperparametersBatchSizeUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodDpoHyperparametersBatchSizeUnion) AsInt() (v int64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodDpoHyperparametersBatchSizeUnion) RawJSON() string { return u.JSON.raw }
-
-func (r *FineTuningJobMethodDpoHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodDpoHyperparametersBetaUnion contains all possible properties
-// and values from [constant.Auto], [float64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfFloat]
-type FineTuningJobMethodDpoHyperparametersBetaUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [float64] instead of an object.
- OfFloat float64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfFloat resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodDpoHyperparametersBetaUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodDpoHyperparametersBetaUnion) AsFloat() (v float64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodDpoHyperparametersBetaUnion) RawJSON() string { return u.JSON.raw }
-
-func (r *FineTuningJobMethodDpoHyperparametersBetaUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion contains all
-// possible properties and values from [constant.Auto], [float64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfFloat]
-type FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [float64] instead of an object.
- OfFloat float64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfFloat resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion) AsFloat() (v float64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion) RawJSON() string {
- return u.JSON.raw
-}
-
-func (r *FineTuningJobMethodDpoHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodDpoHyperparametersNEpochsUnion contains all possible
-// properties and values from [constant.Auto], [int64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfInt]
-type FineTuningJobMethodDpoHyperparametersNEpochsUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [int64] instead of an object.
- OfInt int64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfInt resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodDpoHyperparametersNEpochsUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodDpoHyperparametersNEpochsUnion) AsInt() (v int64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodDpoHyperparametersNEpochsUnion) RawJSON() string { return u.JSON.raw }
-
-func (r *FineTuningJobMethodDpoHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// Configuration for the supervised fine-tuning method.
-type FineTuningJobMethodSupervised struct {
- // The hyperparameters used for the fine-tuning job.
- Hyperparameters FineTuningJobMethodSupervisedHyperparameters `json:"hyperparameters"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- Hyperparameters resp.Field
- ExtraFields map[string]resp.Field
- raw string
- } `json:"-"`
-}
-
-// Returns the unmodified JSON received from the API
-func (r FineTuningJobMethodSupervised) RawJSON() string { return r.JSON.raw }
-func (r *FineTuningJobMethodSupervised) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// The hyperparameters used for the fine-tuning job.
-type FineTuningJobMethodSupervisedHyperparameters struct {
- // Number of examples in each batch. A larger batch size means that model
- // parameters are updated less frequently, but with lower variance.
- BatchSize FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion `json:"batch_size"`
- // Scaling factor for the learning rate. A smaller learning rate may be useful to
- // avoid overfitting.
- LearningRateMultiplier FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier"`
- // The number of epochs to train the model for. An epoch refers to one full cycle
- // through the training dataset.
- NEpochs FineTuningJobMethodSupervisedHyperparametersNEpochsUnion `json:"n_epochs"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- BatchSize resp.Field
- LearningRateMultiplier resp.Field
- NEpochs resp.Field
- ExtraFields map[string]resp.Field
- raw string
- } `json:"-"`
-}
-
-// Returns the unmodified JSON received from the API
-func (r FineTuningJobMethodSupervisedHyperparameters) RawJSON() string { return r.JSON.raw }
-func (r *FineTuningJobMethodSupervisedHyperparameters) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion contains all possible
-// properties and values from [constant.Auto], [int64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfInt]
-type FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [int64] instead of an object.
- OfInt int64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfInt resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion) AsInt() (v int64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion) RawJSON() string {
- return u.JSON.raw
-}
-
-func (r *FineTuningJobMethodSupervisedHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion contains
-// all possible properties and values from [constant.Auto], [float64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfFloat]
-type FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [float64] instead of an object.
- OfFloat float64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfFloat resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion) AsFloat() (v float64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion) RawJSON() string {
- return u.JSON.raw
-}
-
-func (r *FineTuningJobMethodSupervisedHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
-// FineTuningJobMethodSupervisedHyperparametersNEpochsUnion contains all possible
-// properties and values from [constant.Auto], [int64].
-//
-// Use the methods beginning with 'As' to cast the union to one of its variants.
-//
-// If the underlying value is not a json object, one of the following properties
-// will be valid: OfAuto OfInt]
-type FineTuningJobMethodSupervisedHyperparametersNEpochsUnion struct {
- // This field will be present if the value is a [constant.Auto] instead of an
- // object.
- OfAuto constant.Auto `json:",inline"`
- // This field will be present if the value is a [int64] instead of an object.
- OfInt int64 `json:",inline"`
- JSON struct {
- OfAuto resp.Field
- OfInt resp.Field
- raw string
- } `json:"-"`
-}
-
-func (u FineTuningJobMethodSupervisedHyperparametersNEpochsUnion) AsAuto() (v constant.Auto) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-func (u FineTuningJobMethodSupervisedHyperparametersNEpochsUnion) AsInt() (v int64) {
- apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
- return
-}
-
-// Returns the unmodified JSON received from the API
-func (u FineTuningJobMethodSupervisedHyperparametersNEpochsUnion) RawJSON() string { return u.JSON.raw }
-
-func (r *FineTuningJobMethodSupervisedHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error {
- return apijson.UnmarshalRoot(data, r)
-}
-
// Fine-tuning job event object
type FineTuningJobEvent struct {
// The object identifier.
@@ -811,22 +471,21 @@ type FineTuningJobEvent struct {
// The object type, which is always "fine_tuning.job.event".
Object constant.FineTuningJobEvent `json:"object,required"`
// The data associated with the event.
- Data interface{} `json:"data"`
+ Data any `json:"data"`
// The type of event.
//
// Any of "message", "metrics".
Type FineTuningJobEventType `json:"type"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Level resp.Field
- Message resp.Field
- Object resp.Field
- Data resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Level respjson.Field
+ Message respjson.Field
+ Object respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -872,14 +531,13 @@ type FineTuningJobWandbIntegration struct {
// through directly to WandB. Some default tags are generated by OpenAI:
// "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".
Tags []string `json:"tags"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Project resp.Field
- Entity resp.Field
- Name resp.Field
- Tags resp.Field
- ExtraFields map[string]resp.Field
+ Project respjson.Field
+ Entity respjson.Field
+ Name respjson.Field
+ Tags respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -898,12 +556,11 @@ type FineTuningJobWandbIntegrationObject struct {
// explicit display name for your run, add tags to your run, and set a default
// entity (team, username, etc) to be associated with your run.
Wandb FineTuningJobWandbIntegration `json:"wandb,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- Wandb resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ Wandb respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -917,7 +574,7 @@ func (r *FineTuningJobWandbIntegrationObject) UnmarshalJSON(data []byte) error {
type FineTuningJobNewParams struct {
// The name of the model to fine-tune. You can select one of the
// [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned).
- Model string `json:"model,omitzero,required"`
+ Model FineTuningJobNewParamsModel `json:"model,omitzero,required"`
// The ID of an uploaded file that contains training data.
//
// See [upload file](https://platform.openai.com/docs/api-reference/files/create)
@@ -933,7 +590,8 @@ type FineTuningJobNewParams struct {
// [preference](https://platform.openai.com/docs/api-reference/fine-tuning/preference-input)
// format.
//
- // See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ // See the
+ // [fine-tuning guide](https://platform.openai.com/docs/guides/model-optimization)
// for more details.
TrainingFile string `json:"training_file,required"`
// The seed controls the reproducibility of the job. Passing in the same seed and
@@ -956,7 +614,8 @@ type FineTuningJobNewParams struct {
// Your dataset must be formatted as a JSONL file. You must upload your file with
// the purpose `fine-tune`.
//
- // See the [fine-tuning guide](https://platform.openai.com/docs/guides/fine-tuning)
+ // See the
+ // [fine-tuning guide](https://platform.openai.com/docs/guides/model-optimization)
// for more details.
ValidationFile param.Opt[string] `json:"validation_file,omitzero"`
// A list of integrations to enable for your fine-tuning job.
@@ -967,7 +626,7 @@ type FineTuningJobNewParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// The hyperparameters used for the fine-tuning job. This value is now deprecated
// in favor of `method`, and should be passed in under the `method` parameter.
Hyperparameters FineTuningJobNewParamsHyperparameters `json:"hyperparameters,omitzero"`
@@ -976,14 +635,24 @@ type FineTuningJobNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r FineTuningJobNewParams) MarshalJSON() (data []byte, err error) {
type shadow FineTuningJobNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FineTuningJobNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The name of the model to fine-tune. You can select one of the
+// [supported models](https://platform.openai.com/docs/guides/fine-tuning#which-models-can-be-fine-tuned).
+type FineTuningJobNewParamsModel string
+
+const (
+ FineTuningJobNewParamsModelBabbage002 FineTuningJobNewParamsModel = "babbage-002"
+ FineTuningJobNewParamsModelDavinci002 FineTuningJobNewParamsModel = "davinci-002"
+ FineTuningJobNewParamsModelGPT3_5Turbo FineTuningJobNewParamsModel = "gpt-3.5-turbo"
+ FineTuningJobNewParamsModelGPT4oMini FineTuningJobNewParamsModel = "gpt-4o-mini"
+)
// The hyperparameters used for the fine-tuning job. This value is now deprecated
// in favor of `method`, and should be passed in under the `method` parameter.
@@ -1002,34 +671,29 @@ type FineTuningJobNewParamsHyperparameters struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsHyperparameters) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r FineTuningJobNewParamsHyperparameters) MarshalJSON() (data []byte, err error) {
type shadow FineTuningJobNewParamsHyperparameters
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FineTuningJobNewParamsHyperparameters) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type FineTuningJobNewParamsHyperparametersBatchSizeUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
+ // Construct this variant with constant.ValueOf[constant.Auto]()
OfAuto constant.Auto `json:",omitzero,inline"`
OfInt param.Opt[int64] `json:",omitzero,inline"`
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsHyperparametersBatchSizeUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u FineTuningJobNewParamsHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsHyperparametersBatchSizeUnion](u.OfAuto, u.OfInt)
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *FineTuningJobNewParamsHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *FineTuningJobNewParamsHyperparametersBatchSizeUnion) asAny() any {
@@ -1045,20 +709,17 @@ func (u *FineTuningJobNewParamsHyperparametersBatchSizeUnion) asAny() any {
//
// Use [param.IsOmitted] to confirm if a field is set.
type FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
+ // Construct this variant with constant.ValueOf[constant.Auto]()
OfAuto constant.Auto `json:",omitzero,inline"`
OfFloat param.Opt[float64] `json:",omitzero,inline"`
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion](u.OfAuto, u.OfFloat)
+ return param.MarshalUnion(u, u.OfAuto, u.OfFloat)
+}
+func (u *FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) asAny() any {
@@ -1074,20 +735,17 @@ func (u *FineTuningJobNewParamsHyperparametersLearningRateMultiplierUnion) asAny
//
// Use [param.IsOmitted] to confirm if a field is set.
type FineTuningJobNewParamsHyperparametersNEpochsUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
+ // Construct this variant with constant.ValueOf[constant.Auto]()
OfAuto constant.Auto `json:",omitzero,inline"`
OfInt param.Opt[int64] `json:",omitzero,inline"`
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsHyperparametersNEpochsUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u FineTuningJobNewParamsHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsHyperparametersNEpochsUnion](u.OfAuto, u.OfInt)
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *FineTuningJobNewParamsHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *FineTuningJobNewParamsHyperparametersNEpochsUnion) asAny() any {
@@ -1114,15 +772,13 @@ type FineTuningJobNewParamsIntegration struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsIntegration) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r FineTuningJobNewParamsIntegration) MarshalJSON() (data []byte, err error) {
type shadow FineTuningJobNewParamsIntegration
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FineTuningJobNewParamsIntegration) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The settings for your integration with Weights and Biases. This payload
// specifies the project that metrics will be sent to. Optionally, you can set an
@@ -1147,329 +803,45 @@ type FineTuningJobNewParamsIntegrationWandb struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsIntegrationWandb) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r FineTuningJobNewParamsIntegrationWandb) MarshalJSON() (data []byte, err error) {
type shadow FineTuningJobNewParamsIntegrationWandb
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FineTuningJobNewParamsIntegrationWandb) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The method used for fine-tuning.
+//
+// The property Type is required.
type FineTuningJobNewParamsMethod struct {
+ // The type of method. Is either `supervised`, `dpo`, or `reinforcement`.
+ //
+ // Any of "supervised", "dpo", "reinforcement".
+ Type string `json:"type,omitzero,required"`
// Configuration for the DPO fine-tuning method.
- Dpo FineTuningJobNewParamsMethodDpo `json:"dpo,omitzero"`
+ Dpo DpoMethodParam `json:"dpo,omitzero"`
+ // Configuration for the reinforcement fine-tuning method.
+ Reinforcement ReinforcementMethodParam `json:"reinforcement,omitzero"`
// Configuration for the supervised fine-tuning method.
- Supervised FineTuningJobNewParamsMethodSupervised `json:"supervised,omitzero"`
- // The type of method. Is either `supervised` or `dpo`.
- //
- // Any of "supervised", "dpo".
- Type string `json:"type,omitzero"`
+ Supervised SupervisedMethodParam `json:"supervised,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsMethod) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r FineTuningJobNewParamsMethod) MarshalJSON() (data []byte, err error) {
type shadow FineTuningJobNewParamsMethod
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FineTuningJobNewParamsMethod) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[FineTuningJobNewParamsMethod](
- "Type", false, "supervised", "dpo",
+ "type", "supervised", "dpo", "reinforcement",
)
}
-// Configuration for the DPO fine-tuning method.
-type FineTuningJobNewParamsMethodDpo struct {
- // The hyperparameters used for the fine-tuning job.
- Hyperparameters FineTuningJobNewParamsMethodDpoHyperparameters `json:"hyperparameters,omitzero"`
- paramObj
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsMethodDpo) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-func (r FineTuningJobNewParamsMethodDpo) MarshalJSON() (data []byte, err error) {
- type shadow FineTuningJobNewParamsMethodDpo
- return param.MarshalObject(r, (*shadow)(&r))
-}
-
-// The hyperparameters used for the fine-tuning job.
-type FineTuningJobNewParamsMethodDpoHyperparameters struct {
- // Number of examples in each batch. A larger batch size means that model
- // parameters are updated less frequently, but with lower variance.
- BatchSize FineTuningJobNewParamsMethodDpoHyperparametersBatchSizeUnion `json:"batch_size,omitzero"`
- // The beta value for the DPO method. A higher beta value will increase the weight
- // of the penalty between the policy and reference model.
- Beta FineTuningJobNewParamsMethodDpoHyperparametersBetaUnion `json:"beta,omitzero"`
- // Scaling factor for the learning rate. A smaller learning rate may be useful to
- // avoid overfitting.
- LearningRateMultiplier FineTuningJobNewParamsMethodDpoHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"`
- // The number of epochs to train the model for. An epoch refers to one full cycle
- // through the training dataset.
- NEpochs FineTuningJobNewParamsMethodDpoHyperparametersNEpochsUnion `json:"n_epochs,omitzero"`
- paramObj
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsMethodDpoHyperparameters) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
-func (r FineTuningJobNewParamsMethodDpoHyperparameters) MarshalJSON() (data []byte, err error) {
- type shadow FineTuningJobNewParamsMethodDpoHyperparameters
- return param.MarshalObject(r, (*shadow)(&r))
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodDpoHyperparametersBatchSizeUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfInt param.Opt[int64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodDpoHyperparametersBatchSizeUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodDpoHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodDpoHyperparametersBatchSizeUnion](u.OfAuto, u.OfInt)
-}
-
-func (u *FineTuningJobNewParamsMethodDpoHyperparametersBatchSizeUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfInt) {
- return &u.OfInt.Value
- }
- return nil
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodDpoHyperparametersBetaUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfFloat param.Opt[float64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodDpoHyperparametersBetaUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodDpoHyperparametersBetaUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodDpoHyperparametersBetaUnion](u.OfAuto, u.OfFloat)
-}
-
-func (u *FineTuningJobNewParamsMethodDpoHyperparametersBetaUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfFloat) {
- return &u.OfFloat.Value
- }
- return nil
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodDpoHyperparametersLearningRateMultiplierUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfFloat param.Opt[float64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodDpoHyperparametersLearningRateMultiplierUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodDpoHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodDpoHyperparametersLearningRateMultiplierUnion](u.OfAuto, u.OfFloat)
-}
-
-func (u *FineTuningJobNewParamsMethodDpoHyperparametersLearningRateMultiplierUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfFloat) {
- return &u.OfFloat.Value
- }
- return nil
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodDpoHyperparametersNEpochsUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfInt param.Opt[int64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodDpoHyperparametersNEpochsUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodDpoHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodDpoHyperparametersNEpochsUnion](u.OfAuto, u.OfInt)
-}
-
-func (u *FineTuningJobNewParamsMethodDpoHyperparametersNEpochsUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfInt) {
- return &u.OfInt.Value
- }
- return nil
-}
-
-// Configuration for the supervised fine-tuning method.
-type FineTuningJobNewParamsMethodSupervised struct {
- // The hyperparameters used for the fine-tuning job.
- Hyperparameters FineTuningJobNewParamsMethodSupervisedHyperparameters `json:"hyperparameters,omitzero"`
- paramObj
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsMethodSupervised) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
-func (r FineTuningJobNewParamsMethodSupervised) MarshalJSON() (data []byte, err error) {
- type shadow FineTuningJobNewParamsMethodSupervised
- return param.MarshalObject(r, (*shadow)(&r))
-}
-
-// The hyperparameters used for the fine-tuning job.
-type FineTuningJobNewParamsMethodSupervisedHyperparameters struct {
- // Number of examples in each batch. A larger batch size means that model
- // parameters are updated less frequently, but with lower variance.
- BatchSize FineTuningJobNewParamsMethodSupervisedHyperparametersBatchSizeUnion `json:"batch_size,omitzero"`
- // Scaling factor for the learning rate. A smaller learning rate may be useful to
- // avoid overfitting.
- LearningRateMultiplier FineTuningJobNewParamsMethodSupervisedHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"`
- // The number of epochs to train the model for. An epoch refers to one full cycle
- // through the training dataset.
- NEpochs FineTuningJobNewParamsMethodSupervisedHyperparametersNEpochsUnion `json:"n_epochs,omitzero"`
- paramObj
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobNewParamsMethodSupervisedHyperparameters) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
-func (r FineTuningJobNewParamsMethodSupervisedHyperparameters) MarshalJSON() (data []byte, err error) {
- type shadow FineTuningJobNewParamsMethodSupervisedHyperparameters
- return param.MarshalObject(r, (*shadow)(&r))
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodSupervisedHyperparametersBatchSizeUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfInt param.Opt[int64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodSupervisedHyperparametersBatchSizeUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodSupervisedHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodSupervisedHyperparametersBatchSizeUnion](u.OfAuto, u.OfInt)
-}
-
-func (u *FineTuningJobNewParamsMethodSupervisedHyperparametersBatchSizeUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfInt) {
- return &u.OfInt.Value
- }
- return nil
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodSupervisedHyperparametersLearningRateMultiplierUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfFloat param.Opt[float64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodSupervisedHyperparametersLearningRateMultiplierUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodSupervisedHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodSupervisedHyperparametersLearningRateMultiplierUnion](u.OfAuto, u.OfFloat)
-}
-
-func (u *FineTuningJobNewParamsMethodSupervisedHyperparametersLearningRateMultiplierUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfFloat) {
- return &u.OfFloat.Value
- }
- return nil
-}
-
-// Only one field can be non-zero.
-//
-// Use [param.IsOmitted] to confirm if a field is set.
-type FineTuningJobNewParamsMethodSupervisedHyperparametersNEpochsUnion struct {
- // Construct this variant with constant.New[constant.Auto]() Check if union is this
- // variant with !param.IsOmitted(union.OfAuto)
- OfAuto constant.Auto `json:",omitzero,inline"`
- OfInt param.Opt[int64] `json:",omitzero,inline"`
- paramUnion
-}
-
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FineTuningJobNewParamsMethodSupervisedHyperparametersNEpochsUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
-func (u FineTuningJobNewParamsMethodSupervisedHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FineTuningJobNewParamsMethodSupervisedHyperparametersNEpochsUnion](u.OfAuto, u.OfInt)
-}
-
-func (u *FineTuningJobNewParamsMethodSupervisedHyperparametersNEpochsUnion) asAny() any {
- if !param.IsOmitted(u.OfAuto) {
- return &u.OfAuto
- } else if !param.IsOmitted(u.OfInt) {
- return &u.OfInt.Value
- }
- return nil
-}
-
type FineTuningJobListParams struct {
// Identifier for the last job from the previous pagination request.
After param.Opt[string] `query:"after,omitzero" json:"-"`
@@ -15,7 +15,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -82,17 +82,16 @@ type FineTuningJobCheckpoint struct {
Object constant.FineTuningJobCheckpoint `json:"object,required"`
// The step number that the checkpoint was created at.
StepNumber int64 `json:"step_number,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- FineTunedModelCheckpoint resp.Field
- FineTuningJobID resp.Field
- Metrics resp.Field
- Object resp.Field
- StepNumber resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ FineTunedModelCheckpoint respjson.Field
+ FineTuningJobID respjson.Field
+ Metrics respjson.Field
+ Object respjson.Field
+ StepNumber respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -112,17 +111,16 @@ type FineTuningJobCheckpointMetrics struct {
TrainMeanTokenAccuracy float64 `json:"train_mean_token_accuracy"`
ValidLoss float64 `json:"valid_loss"`
ValidMeanTokenAccuracy float64 `json:"valid_mean_token_accuracy"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- FullValidLoss resp.Field
- FullValidMeanTokenAccuracy resp.Field
- Step resp.Field
- TrainLoss resp.Field
- TrainMeanTokenAccuracy resp.Field
- ValidLoss resp.Field
- ValidMeanTokenAccuracy resp.Field
- ExtraFields map[string]resp.Field
+ FullValidLoss respjson.Field
+ FullValidMeanTokenAccuracy respjson.Field
+ Step respjson.Field
+ TrainLoss respjson.Field
+ TrainMeanTokenAccuracy respjson.Field
+ ValidLoss respjson.Field
+ ValidMeanTokenAccuracy respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -141,15 +139,9 @@ type FineTuningJobCheckpointListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FineTuningJobCheckpointListParams) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
-
// URLQuery serializes [FineTuningJobCheckpointListParams]'s query parameters as
// `url.Values`.
-func (r FineTuningJobCheckpointListParams) URLQuery() (v url.Values) {
+func (r FineTuningJobCheckpointListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -0,0 +1,1487 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "encoding/json"
+
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+ "github.com/openai/openai-go/shared/constant"
+)
+
+// FineTuningMethodService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewFineTuningMethodService] method instead.
+type FineTuningMethodService struct {
+ Options []option.RequestOption
+}
+
+// NewFineTuningMethodService generates a new service that applies the given
+// options to each request. These options are applied after the parent client's
+// options (if there is one), and before any request-specific options.
+func NewFineTuningMethodService(opts ...option.RequestOption) (r FineTuningMethodService) {
+ r = FineTuningMethodService{}
+ r.Options = opts
+ return
+}
+
+// The hyperparameters used for the DPO fine-tuning job.
+type DpoHyperparametersResp struct {
+ // Number of examples in each batch. A larger batch size means that model
+ // parameters are updated less frequently, but with lower variance.
+ BatchSize DpoHyperparametersBatchSizeUnionResp `json:"batch_size"`
+ // The beta value for the DPO method. A higher beta value will increase the weight
+ // of the penalty between the policy and reference model.
+ Beta DpoHyperparametersBetaUnionResp `json:"beta"`
+ // Scaling factor for the learning rate. A smaller learning rate may be useful to
+ // avoid overfitting.
+ LearningRateMultiplier DpoHyperparametersLearningRateMultiplierUnionResp `json:"learning_rate_multiplier"`
+ // The number of epochs to train the model for. An epoch refers to one full cycle
+ // through the training dataset.
+ NEpochs DpoHyperparametersNEpochsUnionResp `json:"n_epochs"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ BatchSize respjson.Field
+ Beta respjson.Field
+ LearningRateMultiplier respjson.Field
+ NEpochs respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r DpoHyperparametersResp) RawJSON() string { return r.JSON.raw }
+func (r *DpoHyperparametersResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this DpoHyperparametersResp to a DpoHyperparameters.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// DpoHyperparameters.Overrides()
+func (r DpoHyperparametersResp) ToParam() DpoHyperparameters {
+ return param.Override[DpoHyperparameters](json.RawMessage(r.RawJSON()))
+}
+
+// DpoHyperparametersBatchSizeUnionResp contains all possible properties and values
+// from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type DpoHyperparametersBatchSizeUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u DpoHyperparametersBatchSizeUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u DpoHyperparametersBatchSizeUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u DpoHyperparametersBatchSizeUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *DpoHyperparametersBatchSizeUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// DpoHyperparametersBetaUnionResp contains all possible properties and values from
+// [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfFloat]
+type DpoHyperparametersBetaUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [float64] instead of an object.
+ OfFloat float64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfFloat respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u DpoHyperparametersBetaUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u DpoHyperparametersBetaUnionResp) AsFloat() (v float64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u DpoHyperparametersBetaUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *DpoHyperparametersBetaUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// DpoHyperparametersLearningRateMultiplierUnionResp contains all possible
+// properties and values from [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfFloat]
+type DpoHyperparametersLearningRateMultiplierUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [float64] instead of an object.
+ OfFloat float64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfFloat respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u DpoHyperparametersLearningRateMultiplierUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u DpoHyperparametersLearningRateMultiplierUnionResp) AsFloat() (v float64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u DpoHyperparametersLearningRateMultiplierUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *DpoHyperparametersLearningRateMultiplierUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// DpoHyperparametersNEpochsUnionResp contains all possible properties and values
+// from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type DpoHyperparametersNEpochsUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u DpoHyperparametersNEpochsUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u DpoHyperparametersNEpochsUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u DpoHyperparametersNEpochsUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *DpoHyperparametersNEpochsUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The hyperparameters used for the DPO fine-tuning job.
+type DpoHyperparameters struct {
+ // Number of examples in each batch. A larger batch size means that model
+ // parameters are updated less frequently, but with lower variance.
+ BatchSize DpoHyperparametersBatchSizeUnion `json:"batch_size,omitzero"`
+ // The beta value for the DPO method. A higher beta value will increase the weight
+ // of the penalty between the policy and reference model.
+ Beta DpoHyperparametersBetaUnion `json:"beta,omitzero"`
+ // Scaling factor for the learning rate. A smaller learning rate may be useful to
+ // avoid overfitting.
+ LearningRateMultiplier DpoHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"`
+ // The number of epochs to train the model for. An epoch refers to one full cycle
+ // through the training dataset.
+ NEpochs DpoHyperparametersNEpochsUnion `json:"n_epochs,omitzero"`
+ paramObj
+}
+
+func (r DpoHyperparameters) MarshalJSON() (data []byte, err error) {
+ type shadow DpoHyperparameters
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *DpoHyperparameters) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type DpoHyperparametersBatchSizeUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u DpoHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *DpoHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *DpoHyperparametersBatchSizeUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type DpoHyperparametersBetaUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfFloat param.Opt[float64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u DpoHyperparametersBetaUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfFloat)
+}
+func (u *DpoHyperparametersBetaUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *DpoHyperparametersBetaUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfFloat) {
+ return &u.OfFloat.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type DpoHyperparametersLearningRateMultiplierUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfFloat param.Opt[float64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u DpoHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfFloat)
+}
+func (u *DpoHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *DpoHyperparametersLearningRateMultiplierUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfFloat) {
+ return &u.OfFloat.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type DpoHyperparametersNEpochsUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u DpoHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *DpoHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *DpoHyperparametersNEpochsUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Configuration for the DPO fine-tuning method.
+type DpoMethod struct {
+ // The hyperparameters used for the DPO fine-tuning job.
+ Hyperparameters DpoHyperparametersResp `json:"hyperparameters"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Hyperparameters respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r DpoMethod) RawJSON() string { return r.JSON.raw }
+func (r *DpoMethod) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this DpoMethod to a DpoMethodParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// DpoMethodParam.Overrides()
+func (r DpoMethod) ToParam() DpoMethodParam {
+ return param.Override[DpoMethodParam](json.RawMessage(r.RawJSON()))
+}
+
+// Configuration for the DPO fine-tuning method.
+type DpoMethodParam struct {
+ // The hyperparameters used for the DPO fine-tuning job.
+ Hyperparameters DpoHyperparameters `json:"hyperparameters,omitzero"`
+ paramObj
+}
+
+func (r DpoMethodParam) MarshalJSON() (data []byte, err error) {
+ type shadow DpoMethodParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *DpoMethodParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The hyperparameters used for the reinforcement fine-tuning job.
+type ReinforcementHyperparametersResp struct {
+ // Number of examples in each batch. A larger batch size means that model
+ // parameters are updated less frequently, but with lower variance.
+ BatchSize ReinforcementHyperparametersBatchSizeUnionResp `json:"batch_size"`
+ // Multiplier on amount of compute used for exploring search space during training.
+ ComputeMultiplier ReinforcementHyperparametersComputeMultiplierUnionResp `json:"compute_multiplier"`
+ // The number of training steps between evaluation runs.
+ EvalInterval ReinforcementHyperparametersEvalIntervalUnionResp `json:"eval_interval"`
+ // Number of evaluation samples to generate per training step.
+ EvalSamples ReinforcementHyperparametersEvalSamplesUnionResp `json:"eval_samples"`
+ // Scaling factor for the learning rate. A smaller learning rate may be useful to
+ // avoid overfitting.
+ LearningRateMultiplier ReinforcementHyperparametersLearningRateMultiplierUnionResp `json:"learning_rate_multiplier"`
+ // The number of epochs to train the model for. An epoch refers to one full cycle
+ // through the training dataset.
+ NEpochs ReinforcementHyperparametersNEpochsUnionResp `json:"n_epochs"`
+ // Level of reasoning effort.
+ //
+ // Any of "default", "low", "medium", "high".
+ ReasoningEffort ReinforcementHyperparametersReasoningEffort `json:"reasoning_effort"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ BatchSize respjson.Field
+ ComputeMultiplier respjson.Field
+ EvalInterval respjson.Field
+ EvalSamples respjson.Field
+ LearningRateMultiplier respjson.Field
+ NEpochs respjson.Field
+ ReasoningEffort respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ReinforcementHyperparametersResp) RawJSON() string { return r.JSON.raw }
+func (r *ReinforcementHyperparametersResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this ReinforcementHyperparametersResp to a
+// ReinforcementHyperparameters.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// ReinforcementHyperparameters.Overrides()
+func (r ReinforcementHyperparametersResp) ToParam() ReinforcementHyperparameters {
+ return param.Override[ReinforcementHyperparameters](json.RawMessage(r.RawJSON()))
+}
+
+// ReinforcementHyperparametersBatchSizeUnionResp contains all possible properties
+// and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type ReinforcementHyperparametersBatchSizeUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementHyperparametersBatchSizeUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementHyperparametersBatchSizeUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersBatchSizeUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersBatchSizeUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersComputeMultiplierUnionResp contains all possible
+// properties and values from [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfFloat]
+type ReinforcementHyperparametersComputeMultiplierUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [float64] instead of an object.
+ OfFloat float64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfFloat respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementHyperparametersComputeMultiplierUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementHyperparametersComputeMultiplierUnionResp) AsFloat() (v float64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersComputeMultiplierUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersComputeMultiplierUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersEvalIntervalUnionResp contains all possible
+// properties and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type ReinforcementHyperparametersEvalIntervalUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementHyperparametersEvalIntervalUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementHyperparametersEvalIntervalUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersEvalIntervalUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersEvalIntervalUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersEvalSamplesUnionResp contains all possible
+// properties and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type ReinforcementHyperparametersEvalSamplesUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementHyperparametersEvalSamplesUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementHyperparametersEvalSamplesUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersEvalSamplesUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersEvalSamplesUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersLearningRateMultiplierUnionResp contains all
+// possible properties and values from [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfFloat]
+type ReinforcementHyperparametersLearningRateMultiplierUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [float64] instead of an object.
+ OfFloat float64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfFloat respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementHyperparametersLearningRateMultiplierUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementHyperparametersLearningRateMultiplierUnionResp) AsFloat() (v float64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersLearningRateMultiplierUnionResp) RawJSON() string {
+ return u.JSON.raw
+}
+
+func (r *ReinforcementHyperparametersLearningRateMultiplierUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementHyperparametersNEpochsUnionResp contains all possible properties
+// and values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type ReinforcementHyperparametersNEpochsUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementHyperparametersNEpochsUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementHyperparametersNEpochsUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementHyperparametersNEpochsUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementHyperparametersNEpochsUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Level of reasoning effort.
+type ReinforcementHyperparametersReasoningEffort string
+
+const (
+ ReinforcementHyperparametersReasoningEffortDefault ReinforcementHyperparametersReasoningEffort = "default"
+ ReinforcementHyperparametersReasoningEffortLow ReinforcementHyperparametersReasoningEffort = "low"
+ ReinforcementHyperparametersReasoningEffortMedium ReinforcementHyperparametersReasoningEffort = "medium"
+ ReinforcementHyperparametersReasoningEffortHigh ReinforcementHyperparametersReasoningEffort = "high"
+)
+
+// The hyperparameters used for the reinforcement fine-tuning job.
+type ReinforcementHyperparameters struct {
+ // Number of examples in each batch. A larger batch size means that model
+ // parameters are updated less frequently, but with lower variance.
+ BatchSize ReinforcementHyperparametersBatchSizeUnion `json:"batch_size,omitzero"`
+ // Multiplier on amount of compute used for exploring search space during training.
+ ComputeMultiplier ReinforcementHyperparametersComputeMultiplierUnion `json:"compute_multiplier,omitzero"`
+ // The number of training steps between evaluation runs.
+ EvalInterval ReinforcementHyperparametersEvalIntervalUnion `json:"eval_interval,omitzero"`
+ // Number of evaluation samples to generate per training step.
+ EvalSamples ReinforcementHyperparametersEvalSamplesUnion `json:"eval_samples,omitzero"`
+ // Scaling factor for the learning rate. A smaller learning rate may be useful to
+ // avoid overfitting.
+ LearningRateMultiplier ReinforcementHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"`
+ // The number of epochs to train the model for. An epoch refers to one full cycle
+ // through the training dataset.
+ NEpochs ReinforcementHyperparametersNEpochsUnion `json:"n_epochs,omitzero"`
+ // Level of reasoning effort.
+ //
+ // Any of "default", "low", "medium", "high".
+ ReasoningEffort ReinforcementHyperparametersReasoningEffort `json:"reasoning_effort,omitzero"`
+ paramObj
+}
+
+func (r ReinforcementHyperparameters) MarshalJSON() (data []byte, err error) {
+ type shadow ReinforcementHyperparameters
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ReinforcementHyperparameters) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementHyperparametersBatchSizeUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *ReinforcementHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementHyperparametersBatchSizeUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementHyperparametersComputeMultiplierUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfFloat param.Opt[float64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementHyperparametersComputeMultiplierUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfFloat)
+}
+func (u *ReinforcementHyperparametersComputeMultiplierUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementHyperparametersComputeMultiplierUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfFloat) {
+ return &u.OfFloat.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementHyperparametersEvalIntervalUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementHyperparametersEvalIntervalUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *ReinforcementHyperparametersEvalIntervalUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementHyperparametersEvalIntervalUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementHyperparametersEvalSamplesUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementHyperparametersEvalSamplesUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *ReinforcementHyperparametersEvalSamplesUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementHyperparametersEvalSamplesUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementHyperparametersLearningRateMultiplierUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfFloat param.Opt[float64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfFloat)
+}
+func (u *ReinforcementHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementHyperparametersLearningRateMultiplierUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfFloat) {
+ return &u.OfFloat.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementHyperparametersNEpochsUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *ReinforcementHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementHyperparametersNEpochsUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Configuration for the reinforcement fine-tuning method.
+type ReinforcementMethod struct {
+ // The grader used for the fine-tuning job.
+ Grader ReinforcementMethodGraderUnion `json:"grader,required"`
+ // The hyperparameters used for the reinforcement fine-tuning job.
+ Hyperparameters ReinforcementHyperparametersResp `json:"hyperparameters"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Grader respjson.Field
+ Hyperparameters respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ReinforcementMethod) RawJSON() string { return r.JSON.raw }
+func (r *ReinforcementMethod) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this ReinforcementMethod to a ReinforcementMethodParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// ReinforcementMethodParam.Overrides()
+func (r ReinforcementMethod) ToParam() ReinforcementMethodParam {
+ return param.Override[ReinforcementMethodParam](json.RawMessage(r.RawJSON()))
+}
+
+// ReinforcementMethodGraderUnion contains all possible properties and values from
+// [StringCheckGrader], [TextSimilarityGrader], [PythonGrader], [ScoreModelGrader],
+// [MultiGrader].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+type ReinforcementMethodGraderUnion struct {
+ // This field is a union of [string], [string], [[]ScoreModelGraderInput]
+ Input ReinforcementMethodGraderUnionInput `json:"input"`
+ Name string `json:"name"`
+ // This field is from variant [StringCheckGrader].
+ Operation StringCheckGraderOperation `json:"operation"`
+ Reference string `json:"reference"`
+ Type string `json:"type"`
+ // This field is from variant [TextSimilarityGrader].
+ EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric"`
+ // This field is from variant [PythonGrader].
+ Source string `json:"source"`
+ // This field is from variant [PythonGrader].
+ ImageTag string `json:"image_tag"`
+ // This field is from variant [ScoreModelGrader].
+ Model string `json:"model"`
+ // This field is from variant [ScoreModelGrader].
+ Range []float64 `json:"range"`
+ // This field is from variant [ScoreModelGrader].
+ SamplingParams any `json:"sampling_params"`
+ // This field is from variant [MultiGrader].
+ CalculateOutput string `json:"calculate_output"`
+ // This field is from variant [MultiGrader].
+ Graders MultiGraderGradersUnion `json:"graders"`
+ JSON struct {
+ Input respjson.Field
+ Name respjson.Field
+ Operation respjson.Field
+ Reference respjson.Field
+ Type respjson.Field
+ EvaluationMetric respjson.Field
+ Source respjson.Field
+ ImageTag respjson.Field
+ Model respjson.Field
+ Range respjson.Field
+ SamplingParams respjson.Field
+ CalculateOutput respjson.Field
+ Graders respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ReinforcementMethodGraderUnion) AsStringCheckGrader() (v StringCheckGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementMethodGraderUnion) AsTextSimilarityGrader() (v TextSimilarityGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementMethodGraderUnion) AsPythonGrader() (v PythonGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementMethodGraderUnion) AsScoreModelGrader() (v ScoreModelGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ReinforcementMethodGraderUnion) AsMultiGrader() (v MultiGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ReinforcementMethodGraderUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *ReinforcementMethodGraderUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ReinforcementMethodGraderUnionInput is an implicit subunion of
+// [ReinforcementMethodGraderUnion]. ReinforcementMethodGraderUnionInput provides
+// convenient access to the sub-properties of the union.
+//
+// For type safety it is recommended to directly use a variant of the
+// [ReinforcementMethodGraderUnion].
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString OfScoreModelGraderInputArray]
+type ReinforcementMethodGraderUnionInput struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ // This field will be present if the value is a [[]ScoreModelGraderInput] instead
+ // of an object.
+ OfScoreModelGraderInputArray []ScoreModelGraderInput `json:",inline"`
+ JSON struct {
+ OfString respjson.Field
+ OfScoreModelGraderInputArray respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (r *ReinforcementMethodGraderUnionInput) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Configuration for the reinforcement fine-tuning method.
+//
+// The property Grader is required.
+type ReinforcementMethodParam struct {
+ // The grader used for the fine-tuning job.
+ Grader ReinforcementMethodGraderUnionParam `json:"grader,omitzero,required"`
+ // The hyperparameters used for the reinforcement fine-tuning job.
+ Hyperparameters ReinforcementHyperparameters `json:"hyperparameters,omitzero"`
+ paramObj
+}
+
+func (r ReinforcementMethodParam) MarshalJSON() (data []byte, err error) {
+ type shadow ReinforcementMethodParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ReinforcementMethodParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ReinforcementMethodGraderUnionParam struct {
+ OfStringCheckGrader *StringCheckGraderParam `json:",omitzero,inline"`
+ OfTextSimilarityGrader *TextSimilarityGraderParam `json:",omitzero,inline"`
+ OfPythonGrader *PythonGraderParam `json:",omitzero,inline"`
+ OfScoreModelGrader *ScoreModelGraderParam `json:",omitzero,inline"`
+ OfMultiGrader *MultiGraderParam `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ReinforcementMethodGraderUnionParam) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfStringCheckGrader,
+ u.OfTextSimilarityGrader,
+ u.OfPythonGrader,
+ u.OfScoreModelGrader,
+ u.OfMultiGrader)
+}
+func (u *ReinforcementMethodGraderUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ReinforcementMethodGraderUnionParam) asAny() any {
+ if !param.IsOmitted(u.OfStringCheckGrader) {
+ return u.OfStringCheckGrader
+ } else if !param.IsOmitted(u.OfTextSimilarityGrader) {
+ return u.OfTextSimilarityGrader
+ } else if !param.IsOmitted(u.OfPythonGrader) {
+ return u.OfPythonGrader
+ } else if !param.IsOmitted(u.OfScoreModelGrader) {
+ return u.OfScoreModelGrader
+ } else if !param.IsOmitted(u.OfMultiGrader) {
+ return u.OfMultiGrader
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetOperation() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Operation)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetEvaluationMetric() *string {
+ if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.EvaluationMetric)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetSource() *string {
+ if vt := u.OfPythonGrader; vt != nil {
+ return &vt.Source
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetImageTag() *string {
+ if vt := u.OfPythonGrader; vt != nil && vt.ImageTag.Valid() {
+ return &vt.ImageTag.Value
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetModel() *string {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return &vt.Model
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetRange() []float64 {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return vt.Range
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetSamplingParams() *any {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return &vt.SamplingParams
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetCalculateOutput() *string {
+ if vt := u.OfMultiGrader; vt != nil {
+ return &vt.CalculateOutput
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetGraders() *MultiGraderGradersUnionParam {
+ if vt := u.OfMultiGrader; vt != nil {
+ return &vt.Graders
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetName() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfPythonGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfMultiGrader; vt != nil {
+ return (*string)(&vt.Name)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetReference() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Reference)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Reference)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ReinforcementMethodGraderUnionParam) GetType() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfPythonGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfMultiGrader; vt != nil {
+ return (*string)(&vt.Type)
+ }
+ return nil
+}
+
+// Returns a subunion which exports methods to access subproperties
+//
+// Or use AsAny() to get the underlying value
+func (u ReinforcementMethodGraderUnionParam) GetInput() (res reinforcementMethodGraderUnionParamInput) {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ res.any = &vt.Input
+ }
+ return
+}
+
+// Can have the runtime types [*string], [\*[]ScoreModelGraderInputParam]
+type reinforcementMethodGraderUnionParamInput struct{ any }
+
+// Use the following switch statement to get the type of the union:
+//
+// switch u.AsAny().(type) {
+// case *string:
+// case *[]openai.ScoreModelGraderInputParam:
+// default:
+// fmt.Errorf("not present")
+// }
+func (u reinforcementMethodGraderUnionParamInput) AsAny() any { return u.any }
+
+// The hyperparameters used for the fine-tuning job.
+type SupervisedHyperparametersResp struct {
+ // Number of examples in each batch. A larger batch size means that model
+ // parameters are updated less frequently, but with lower variance.
+ BatchSize SupervisedHyperparametersBatchSizeUnionResp `json:"batch_size"`
+ // Scaling factor for the learning rate. A smaller learning rate may be useful to
+ // avoid overfitting.
+ LearningRateMultiplier SupervisedHyperparametersLearningRateMultiplierUnionResp `json:"learning_rate_multiplier"`
+ // The number of epochs to train the model for. An epoch refers to one full cycle
+ // through the training dataset.
+ NEpochs SupervisedHyperparametersNEpochsUnionResp `json:"n_epochs"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ BatchSize respjson.Field
+ LearningRateMultiplier respjson.Field
+ NEpochs respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r SupervisedHyperparametersResp) RawJSON() string { return r.JSON.raw }
+func (r *SupervisedHyperparametersResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this SupervisedHyperparametersResp to a
+// SupervisedHyperparameters.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// SupervisedHyperparameters.Overrides()
+func (r SupervisedHyperparametersResp) ToParam() SupervisedHyperparameters {
+ return param.Override[SupervisedHyperparameters](json.RawMessage(r.RawJSON()))
+}
+
+// SupervisedHyperparametersBatchSizeUnionResp contains all possible properties and
+// values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type SupervisedHyperparametersBatchSizeUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u SupervisedHyperparametersBatchSizeUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u SupervisedHyperparametersBatchSizeUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u SupervisedHyperparametersBatchSizeUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *SupervisedHyperparametersBatchSizeUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// SupervisedHyperparametersLearningRateMultiplierUnionResp contains all possible
+// properties and values from [constant.Auto], [float64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfFloat]
+type SupervisedHyperparametersLearningRateMultiplierUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [float64] instead of an object.
+ OfFloat float64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfFloat respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u SupervisedHyperparametersLearningRateMultiplierUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u SupervisedHyperparametersLearningRateMultiplierUnionResp) AsFloat() (v float64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u SupervisedHyperparametersLearningRateMultiplierUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *SupervisedHyperparametersLearningRateMultiplierUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// SupervisedHyperparametersNEpochsUnionResp contains all possible properties and
+// values from [constant.Auto], [int64].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfAuto OfInt]
+type SupervisedHyperparametersNEpochsUnionResp struct {
+ // This field will be present if the value is a [constant.Auto] instead of an
+ // object.
+ OfAuto constant.Auto `json:",inline"`
+ // This field will be present if the value is a [int64] instead of an object.
+ OfInt int64 `json:",inline"`
+ JSON struct {
+ OfAuto respjson.Field
+ OfInt respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u SupervisedHyperparametersNEpochsUnionResp) AsAuto() (v constant.Auto) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u SupervisedHyperparametersNEpochsUnionResp) AsInt() (v int64) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u SupervisedHyperparametersNEpochsUnionResp) RawJSON() string { return u.JSON.raw }
+
+func (r *SupervisedHyperparametersNEpochsUnionResp) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The hyperparameters used for the fine-tuning job.
+type SupervisedHyperparameters struct {
+ // Number of examples in each batch. A larger batch size means that model
+ // parameters are updated less frequently, but with lower variance.
+ BatchSize SupervisedHyperparametersBatchSizeUnion `json:"batch_size,omitzero"`
+ // Scaling factor for the learning rate. A smaller learning rate may be useful to
+ // avoid overfitting.
+ LearningRateMultiplier SupervisedHyperparametersLearningRateMultiplierUnion `json:"learning_rate_multiplier,omitzero"`
+ // The number of epochs to train the model for. An epoch refers to one full cycle
+ // through the training dataset.
+ NEpochs SupervisedHyperparametersNEpochsUnion `json:"n_epochs,omitzero"`
+ paramObj
+}
+
+func (r SupervisedHyperparameters) MarshalJSON() (data []byte, err error) {
+ type shadow SupervisedHyperparameters
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *SupervisedHyperparameters) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type SupervisedHyperparametersBatchSizeUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u SupervisedHyperparametersBatchSizeUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *SupervisedHyperparametersBatchSizeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *SupervisedHyperparametersBatchSizeUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type SupervisedHyperparametersLearningRateMultiplierUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfFloat param.Opt[float64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u SupervisedHyperparametersLearningRateMultiplierUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfFloat)
+}
+func (u *SupervisedHyperparametersLearningRateMultiplierUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *SupervisedHyperparametersLearningRateMultiplierUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfFloat) {
+ return &u.OfFloat.Value
+ }
+ return nil
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type SupervisedHyperparametersNEpochsUnion struct {
+ // Construct this variant with constant.ValueOf[constant.Auto]()
+ OfAuto constant.Auto `json:",omitzero,inline"`
+ OfInt param.Opt[int64] `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u SupervisedHyperparametersNEpochsUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfAuto, u.OfInt)
+}
+func (u *SupervisedHyperparametersNEpochsUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *SupervisedHyperparametersNEpochsUnion) asAny() any {
+ if !param.IsOmitted(u.OfAuto) {
+ return &u.OfAuto
+ } else if !param.IsOmitted(u.OfInt) {
+ return &u.OfInt.Value
+ }
+ return nil
+}
+
+// Configuration for the supervised fine-tuning method.
+type SupervisedMethod struct {
+ // The hyperparameters used for the fine-tuning job.
+ Hyperparameters SupervisedHyperparametersResp `json:"hyperparameters"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Hyperparameters respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r SupervisedMethod) RawJSON() string { return r.JSON.raw }
+func (r *SupervisedMethod) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this SupervisedMethod to a SupervisedMethodParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// SupervisedMethodParam.Overrides()
+func (r SupervisedMethod) ToParam() SupervisedMethodParam {
+ return param.Override[SupervisedMethodParam](json.RawMessage(r.RawJSON()))
+}
+
+// Configuration for the supervised fine-tuning method.
+type SupervisedMethodParam struct {
+ // The hyperparameters used for the fine-tuning job.
+ Hyperparameters SupervisedHyperparameters `json:"hyperparameters,omitzero"`
+ paramObj
+}
+
+func (r SupervisedMethodParam) MarshalJSON() (data []byte, err error) {
+ type shadow SupervisedMethodParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *SupervisedMethodParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
@@ -0,0 +1,28 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "github.com/openai/openai-go/option"
+)
+
+// GraderService contains methods and other services that help with interacting
+// with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewGraderService] method instead.
+type GraderService struct {
+ Options []option.RequestOption
+ GraderModels GraderGraderModelService
+}
+
+// NewGraderService generates a new service that applies the given options to each
+// request. These options are applied after the parent client's options (if there
+// is one), and before any request-specific options.
+func NewGraderService(opts ...option.RequestOption) (r GraderService) {
+ r = GraderService{}
+ r.Options = opts
+ r.GraderModels = NewGraderGraderModelService(opts...)
+ return
+}
@@ -0,0 +1,1179 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package openai
+
+import (
+ "encoding/json"
+
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+ "github.com/openai/openai-go/responses"
+ "github.com/openai/openai-go/shared/constant"
+)
+
+// GraderGraderModelService contains methods and other services that help with
+// interacting with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewGraderGraderModelService] method instead.
+type GraderGraderModelService struct {
+ Options []option.RequestOption
+}
+
+// NewGraderGraderModelService generates a new service that applies the given
+// options to each request. These options are applied after the parent client's
+// options (if there is one), and before any request-specific options.
+func NewGraderGraderModelService(opts ...option.RequestOption) (r GraderGraderModelService) {
+ r = GraderGraderModelService{}
+ r.Options = opts
+ return
+}
+
+// A LabelModelGrader object which uses a model to assign labels to each item in
+// the evaluation.
+type LabelModelGrader struct {
+ Input []LabelModelGraderInput `json:"input,required"`
+ // The labels to assign to each item in the evaluation.
+ Labels []string `json:"labels,required"`
+ // The model to use for the evaluation. Must support structured outputs.
+ Model string `json:"model,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The labels that indicate a passing result. Must be a subset of labels.
+ PassingLabels []string `json:"passing_labels,required"`
+ // The object type, which is always `label_model`.
+ Type constant.LabelModel `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Input respjson.Field
+ Labels respjson.Field
+ Model respjson.Field
+ Name respjson.Field
+ PassingLabels respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r LabelModelGrader) RawJSON() string { return r.JSON.raw }
+func (r *LabelModelGrader) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this LabelModelGrader to a LabelModelGraderParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// LabelModelGraderParam.Overrides()
+func (r LabelModelGrader) ToParam() LabelModelGraderParam {
+ return param.Override[LabelModelGraderParam](json.RawMessage(r.RawJSON()))
+}
+
+// A message input to the model with a role indicating instruction following
+// hierarchy. Instructions given with the `developer` or `system` role take
+// precedence over instructions given with the `user` role. Messages with the
+// `assistant` role are presumed to have been generated by the model in previous
+// interactions.
+type LabelModelGraderInput struct {
+ // Text inputs to the model - can contain template strings.
+ Content LabelModelGraderInputContentUnion `json:"content,required"`
+ // The role of the message input. One of `user`, `assistant`, `system`, or
+ // `developer`.
+ //
+ // Any of "user", "assistant", "system", "developer".
+ Role string `json:"role,required"`
+ // The type of the message input. Always `message`.
+ //
+ // Any of "message".
+ Type string `json:"type"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Content respjson.Field
+ Role respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r LabelModelGraderInput) RawJSON() string { return r.JSON.raw }
+func (r *LabelModelGraderInput) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// LabelModelGraderInputContentUnion contains all possible properties and values
+// from [string], [responses.ResponseInputText],
+// [LabelModelGraderInputContentOutputText].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString]
+type LabelModelGraderInputContentUnion struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ Text string `json:"text"`
+ Type string `json:"type"`
+ JSON struct {
+ OfString respjson.Field
+ Text respjson.Field
+ Type respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u LabelModelGraderInputContentUnion) AsString() (v string) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u LabelModelGraderInputContentUnion) AsInputText() (v responses.ResponseInputText) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u LabelModelGraderInputContentUnion) AsOutputText() (v LabelModelGraderInputContentOutputText) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u LabelModelGraderInputContentUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *LabelModelGraderInputContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A text output from the model.
+type LabelModelGraderInputContentOutputText struct {
+ // The text output from the model.
+ Text string `json:"text,required"`
+ // The type of the output text. Always `output_text`.
+ Type constant.OutputText `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r LabelModelGraderInputContentOutputText) RawJSON() string { return r.JSON.raw }
+func (r *LabelModelGraderInputContentOutputText) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A LabelModelGrader object which uses a model to assign labels to each item in
+// the evaluation.
+//
+// The properties Input, Labels, Model, Name, PassingLabels, Type are required.
+type LabelModelGraderParam struct {
+ Input []LabelModelGraderInputParam `json:"input,omitzero,required"`
+ // The labels to assign to each item in the evaluation.
+ Labels []string `json:"labels,omitzero,required"`
+ // The model to use for the evaluation. Must support structured outputs.
+ Model string `json:"model,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The labels that indicate a passing result. Must be a subset of labels.
+ PassingLabels []string `json:"passing_labels,omitzero,required"`
+ // The object type, which is always `label_model`.
+ //
+ // This field can be elided, and will marshal its zero value as "label_model".
+ Type constant.LabelModel `json:"type,required"`
+ paramObj
+}
+
+func (r LabelModelGraderParam) MarshalJSON() (data []byte, err error) {
+ type shadow LabelModelGraderParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *LabelModelGraderParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A message input to the model with a role indicating instruction following
+// hierarchy. Instructions given with the `developer` or `system` role take
+// precedence over instructions given with the `user` role. Messages with the
+// `assistant` role are presumed to have been generated by the model in previous
+// interactions.
+//
+// The properties Content, Role are required.
+type LabelModelGraderInputParam struct {
+ // Text inputs to the model - can contain template strings.
+ Content LabelModelGraderInputContentUnionParam `json:"content,omitzero,required"`
+ // The role of the message input. One of `user`, `assistant`, `system`, or
+ // `developer`.
+ //
+ // Any of "user", "assistant", "system", "developer".
+ Role string `json:"role,omitzero,required"`
+ // The type of the message input. Always `message`.
+ //
+ // Any of "message".
+ Type string `json:"type,omitzero"`
+ paramObj
+}
+
+func (r LabelModelGraderInputParam) MarshalJSON() (data []byte, err error) {
+ type shadow LabelModelGraderInputParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *LabelModelGraderInputParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+func init() {
+ apijson.RegisterFieldValidator[LabelModelGraderInputParam](
+ "role", "user", "assistant", "system", "developer",
+ )
+ apijson.RegisterFieldValidator[LabelModelGraderInputParam](
+ "type", "message",
+ )
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type LabelModelGraderInputContentUnionParam struct {
+ OfString param.Opt[string] `json:",omitzero,inline"`
+ OfInputText *responses.ResponseInputTextParam `json:",omitzero,inline"`
+ OfOutputText *LabelModelGraderInputContentOutputTextParam `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u LabelModelGraderInputContentUnionParam) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfString, u.OfInputText, u.OfOutputText)
+}
+func (u *LabelModelGraderInputContentUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *LabelModelGraderInputContentUnionParam) asAny() any {
+ if !param.IsOmitted(u.OfString) {
+ return &u.OfString.Value
+ } else if !param.IsOmitted(u.OfInputText) {
+ return u.OfInputText
+ } else if !param.IsOmitted(u.OfOutputText) {
+ return u.OfOutputText
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u LabelModelGraderInputContentUnionParam) GetText() *string {
+ if vt := u.OfInputText; vt != nil {
+ return (*string)(&vt.Text)
+ } else if vt := u.OfOutputText; vt != nil {
+ return (*string)(&vt.Text)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u LabelModelGraderInputContentUnionParam) GetType() *string {
+ if vt := u.OfInputText; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfOutputText; vt != nil {
+ return (*string)(&vt.Type)
+ }
+ return nil
+}
+
+// A text output from the model.
+//
+// The properties Text, Type are required.
+type LabelModelGraderInputContentOutputTextParam struct {
+ // The text output from the model.
+ Text string `json:"text,required"`
+ // The type of the output text. Always `output_text`.
+ //
+ // This field can be elided, and will marshal its zero value as "output_text".
+ Type constant.OutputText `json:"type,required"`
+ paramObj
+}
+
+func (r LabelModelGraderInputContentOutputTextParam) MarshalJSON() (data []byte, err error) {
+ type shadow LabelModelGraderInputContentOutputTextParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *LabelModelGraderInputContentOutputTextParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A MultiGrader object combines the output of multiple graders to produce a single
+// score.
+type MultiGrader struct {
+ // A formula to calculate the output based on grader results.
+ CalculateOutput string `json:"calculate_output,required"`
+ // A StringCheckGrader object that performs a string comparison between input and
+ // reference using a specified operation.
+ Graders MultiGraderGradersUnion `json:"graders,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The object type, which is always `multi`.
+ Type constant.Multi `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ CalculateOutput respjson.Field
+ Graders respjson.Field
+ Name respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r MultiGrader) RawJSON() string { return r.JSON.raw }
+func (r *MultiGrader) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this MultiGrader to a MultiGraderParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// MultiGraderParam.Overrides()
+func (r MultiGrader) ToParam() MultiGraderParam {
+ return param.Override[MultiGraderParam](json.RawMessage(r.RawJSON()))
+}
+
+// MultiGraderGradersUnion contains all possible properties and values from
+// [StringCheckGrader], [TextSimilarityGrader], [PythonGrader], [ScoreModelGrader],
+// [LabelModelGrader].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+type MultiGraderGradersUnion struct {
+ // This field is a union of [string], [string], [[]ScoreModelGraderInput],
+ // [[]LabelModelGraderInput]
+ Input MultiGraderGradersUnionInput `json:"input"`
+ Name string `json:"name"`
+ // This field is from variant [StringCheckGrader].
+ Operation StringCheckGraderOperation `json:"operation"`
+ Reference string `json:"reference"`
+ Type string `json:"type"`
+ // This field is from variant [TextSimilarityGrader].
+ EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric"`
+ // This field is from variant [PythonGrader].
+ Source string `json:"source"`
+ // This field is from variant [PythonGrader].
+ ImageTag string `json:"image_tag"`
+ Model string `json:"model"`
+ // This field is from variant [ScoreModelGrader].
+ Range []float64 `json:"range"`
+ // This field is from variant [ScoreModelGrader].
+ SamplingParams any `json:"sampling_params"`
+ // This field is from variant [LabelModelGrader].
+ Labels []string `json:"labels"`
+ // This field is from variant [LabelModelGrader].
+ PassingLabels []string `json:"passing_labels"`
+ JSON struct {
+ Input respjson.Field
+ Name respjson.Field
+ Operation respjson.Field
+ Reference respjson.Field
+ Type respjson.Field
+ EvaluationMetric respjson.Field
+ Source respjson.Field
+ ImageTag respjson.Field
+ Model respjson.Field
+ Range respjson.Field
+ SamplingParams respjson.Field
+ Labels respjson.Field
+ PassingLabels respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u MultiGraderGradersUnion) AsStringCheckGrader() (v StringCheckGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u MultiGraderGradersUnion) AsTextSimilarityGrader() (v TextSimilarityGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u MultiGraderGradersUnion) AsPythonGrader() (v PythonGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u MultiGraderGradersUnion) AsScoreModelGrader() (v ScoreModelGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u MultiGraderGradersUnion) AsLabelModelGrader() (v LabelModelGrader) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u MultiGraderGradersUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *MultiGraderGradersUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// MultiGraderGradersUnionInput is an implicit subunion of
+// [MultiGraderGradersUnion]. MultiGraderGradersUnionInput provides convenient
+// access to the sub-properties of the union.
+//
+// For type safety it is recommended to directly use a variant of the
+// [MultiGraderGradersUnion].
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString OfScoreModelGraderInputArray
+// OfLabelModelGraderInputArray]
+type MultiGraderGradersUnionInput struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ // This field will be present if the value is a [[]ScoreModelGraderInput] instead
+ // of an object.
+ OfScoreModelGraderInputArray []ScoreModelGraderInput `json:",inline"`
+ // This field will be present if the value is a [[]LabelModelGraderInput] instead
+ // of an object.
+ OfLabelModelGraderInputArray []LabelModelGraderInput `json:",inline"`
+ JSON struct {
+ OfString respjson.Field
+ OfScoreModelGraderInputArray respjson.Field
+ OfLabelModelGraderInputArray respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (r *MultiGraderGradersUnionInput) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A MultiGrader object combines the output of multiple graders to produce a single
+// score.
+//
+// The properties CalculateOutput, Graders, Name, Type are required.
+type MultiGraderParam struct {
+ // A formula to calculate the output based on grader results.
+ CalculateOutput string `json:"calculate_output,required"`
+ // A StringCheckGrader object that performs a string comparison between input and
+ // reference using a specified operation.
+ Graders MultiGraderGradersUnionParam `json:"graders,omitzero,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The object type, which is always `multi`.
+ //
+ // This field can be elided, and will marshal its zero value as "multi".
+ Type constant.Multi `json:"type,required"`
+ paramObj
+}
+
+func (r MultiGraderParam) MarshalJSON() (data []byte, err error) {
+ type shadow MultiGraderParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *MultiGraderParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type MultiGraderGradersUnionParam struct {
+ OfStringCheckGrader *StringCheckGraderParam `json:",omitzero,inline"`
+ OfTextSimilarityGrader *TextSimilarityGraderParam `json:",omitzero,inline"`
+ OfPythonGrader *PythonGraderParam `json:",omitzero,inline"`
+ OfScoreModelGrader *ScoreModelGraderParam `json:",omitzero,inline"`
+ OfLabelModelGrader *LabelModelGraderParam `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u MultiGraderGradersUnionParam) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfStringCheckGrader,
+ u.OfTextSimilarityGrader,
+ u.OfPythonGrader,
+ u.OfScoreModelGrader,
+ u.OfLabelModelGrader)
+}
+func (u *MultiGraderGradersUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *MultiGraderGradersUnionParam) asAny() any {
+ if !param.IsOmitted(u.OfStringCheckGrader) {
+ return u.OfStringCheckGrader
+ } else if !param.IsOmitted(u.OfTextSimilarityGrader) {
+ return u.OfTextSimilarityGrader
+ } else if !param.IsOmitted(u.OfPythonGrader) {
+ return u.OfPythonGrader
+ } else if !param.IsOmitted(u.OfScoreModelGrader) {
+ return u.OfScoreModelGrader
+ } else if !param.IsOmitted(u.OfLabelModelGrader) {
+ return u.OfLabelModelGrader
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetOperation() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Operation)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetEvaluationMetric() *string {
+ if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.EvaluationMetric)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetSource() *string {
+ if vt := u.OfPythonGrader; vt != nil {
+ return &vt.Source
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetImageTag() *string {
+ if vt := u.OfPythonGrader; vt != nil && vt.ImageTag.Valid() {
+ return &vt.ImageTag.Value
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetRange() []float64 {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return vt.Range
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetSamplingParams() *any {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return &vt.SamplingParams
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetLabels() []string {
+ if vt := u.OfLabelModelGrader; vt != nil {
+ return vt.Labels
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetPassingLabels() []string {
+ if vt := u.OfLabelModelGrader; vt != nil {
+ return vt.PassingLabels
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetName() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfPythonGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Name)
+ } else if vt := u.OfLabelModelGrader; vt != nil {
+ return (*string)(&vt.Name)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetReference() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Reference)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Reference)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetType() *string {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfPythonGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfLabelModelGrader; vt != nil {
+ return (*string)(&vt.Type)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u MultiGraderGradersUnionParam) GetModel() *string {
+ if vt := u.OfScoreModelGrader; vt != nil {
+ return (*string)(&vt.Model)
+ } else if vt := u.OfLabelModelGrader; vt != nil {
+ return (*string)(&vt.Model)
+ }
+ return nil
+}
+
+// Returns a subunion which exports methods to access subproperties
+//
+// Or use AsAny() to get the underlying value
+func (u MultiGraderGradersUnionParam) GetInput() (res multiGraderGradersUnionParamInput) {
+ if vt := u.OfStringCheckGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfTextSimilarityGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfScoreModelGrader; vt != nil {
+ res.any = &vt.Input
+ } else if vt := u.OfLabelModelGrader; vt != nil {
+ res.any = &vt.Input
+ }
+ return
+}
+
+// Can have the runtime types [*string], [_[]ScoreModelGraderInputParam],
+// [_[]LabelModelGraderInputParam]
+type multiGraderGradersUnionParamInput struct{ any }
+
+// Use the following switch statement to get the type of the union:
+//
+// switch u.AsAny().(type) {
+// case *string:
+// case *[]openai.ScoreModelGraderInputParam:
+// case *[]openai.LabelModelGraderInputParam:
+// default:
+// fmt.Errorf("not present")
+// }
+func (u multiGraderGradersUnionParamInput) AsAny() any { return u.any }
+
+// A PythonGrader object that runs a python script on the input.
+type PythonGrader struct {
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The source code of the python script.
+ Source string `json:"source,required"`
+ // The object type, which is always `python`.
+ Type constant.Python `json:"type,required"`
+ // The image tag to use for the python script.
+ ImageTag string `json:"image_tag"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Name respjson.Field
+ Source respjson.Field
+ Type respjson.Field
+ ImageTag respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r PythonGrader) RawJSON() string { return r.JSON.raw }
+func (r *PythonGrader) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this PythonGrader to a PythonGraderParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// PythonGraderParam.Overrides()
+func (r PythonGrader) ToParam() PythonGraderParam {
+ return param.Override[PythonGraderParam](json.RawMessage(r.RawJSON()))
+}
+
+// A PythonGrader object that runs a python script on the input.
+//
+// The properties Name, Source, Type are required.
+type PythonGraderParam struct {
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The source code of the python script.
+ Source string `json:"source,required"`
+ // The image tag to use for the python script.
+ ImageTag param.Opt[string] `json:"image_tag,omitzero"`
+ // The object type, which is always `python`.
+ //
+ // This field can be elided, and will marshal its zero value as "python".
+ Type constant.Python `json:"type,required"`
+ paramObj
+}
+
+func (r PythonGraderParam) MarshalJSON() (data []byte, err error) {
+ type shadow PythonGraderParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *PythonGraderParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A ScoreModelGrader object that uses a model to assign a score to the input.
+type ScoreModelGrader struct {
+ // The input text. This may include template strings.
+ Input []ScoreModelGraderInput `json:"input,required"`
+ // The model to use for the evaluation.
+ Model string `json:"model,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The object type, which is always `score_model`.
+ Type constant.ScoreModel `json:"type,required"`
+ // The range of the score. Defaults to `[0, 1]`.
+ Range []float64 `json:"range"`
+ // The sampling parameters for the model.
+ SamplingParams any `json:"sampling_params"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Input respjson.Field
+ Model respjson.Field
+ Name respjson.Field
+ Type respjson.Field
+ Range respjson.Field
+ SamplingParams respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ScoreModelGrader) RawJSON() string { return r.JSON.raw }
+func (r *ScoreModelGrader) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this ScoreModelGrader to a ScoreModelGraderParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// ScoreModelGraderParam.Overrides()
+func (r ScoreModelGrader) ToParam() ScoreModelGraderParam {
+ return param.Override[ScoreModelGraderParam](json.RawMessage(r.RawJSON()))
+}
+
+// A message input to the model with a role indicating instruction following
+// hierarchy. Instructions given with the `developer` or `system` role take
+// precedence over instructions given with the `user` role. Messages with the
+// `assistant` role are presumed to have been generated by the model in previous
+// interactions.
+type ScoreModelGraderInput struct {
+ // Text inputs to the model - can contain template strings.
+ Content ScoreModelGraderInputContentUnion `json:"content,required"`
+ // The role of the message input. One of `user`, `assistant`, `system`, or
+ // `developer`.
+ //
+ // Any of "user", "assistant", "system", "developer".
+ Role string `json:"role,required"`
+ // The type of the message input. Always `message`.
+ //
+ // Any of "message".
+ Type string `json:"type"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Content respjson.Field
+ Role respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ScoreModelGraderInput) RawJSON() string { return r.JSON.raw }
+func (r *ScoreModelGraderInput) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ScoreModelGraderInputContentUnion contains all possible properties and values
+// from [string], [responses.ResponseInputText],
+// [ScoreModelGraderInputContentOutputText].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString]
+type ScoreModelGraderInputContentUnion struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ Text string `json:"text"`
+ Type string `json:"type"`
+ JSON struct {
+ OfString respjson.Field
+ Text respjson.Field
+ Type respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ScoreModelGraderInputContentUnion) AsString() (v string) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ScoreModelGraderInputContentUnion) AsInputText() (v responses.ResponseInputText) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ScoreModelGraderInputContentUnion) AsOutputText() (v ScoreModelGraderInputContentOutputText) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ScoreModelGraderInputContentUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *ScoreModelGraderInputContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A text output from the model.
+type ScoreModelGraderInputContentOutputText struct {
+ // The text output from the model.
+ Text string `json:"text,required"`
+ // The type of the output text. Always `output_text`.
+ Type constant.OutputText `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ScoreModelGraderInputContentOutputText) RawJSON() string { return r.JSON.raw }
+func (r *ScoreModelGraderInputContentOutputText) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A ScoreModelGrader object that uses a model to assign a score to the input.
+//
+// The properties Input, Model, Name, Type are required.
+type ScoreModelGraderParam struct {
+ // The input text. This may include template strings.
+ Input []ScoreModelGraderInputParam `json:"input,omitzero,required"`
+ // The model to use for the evaluation.
+ Model string `json:"model,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The range of the score. Defaults to `[0, 1]`.
+ Range []float64 `json:"range,omitzero"`
+ // The sampling parameters for the model.
+ SamplingParams any `json:"sampling_params,omitzero"`
+ // The object type, which is always `score_model`.
+ //
+ // This field can be elided, and will marshal its zero value as "score_model".
+ Type constant.ScoreModel `json:"type,required"`
+ paramObj
+}
+
+func (r ScoreModelGraderParam) MarshalJSON() (data []byte, err error) {
+ type shadow ScoreModelGraderParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ScoreModelGraderParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A message input to the model with a role indicating instruction following
+// hierarchy. Instructions given with the `developer` or `system` role take
+// precedence over instructions given with the `user` role. Messages with the
+// `assistant` role are presumed to have been generated by the model in previous
+// interactions.
+//
+// The properties Content, Role are required.
+type ScoreModelGraderInputParam struct {
+ // Text inputs to the model - can contain template strings.
+ Content ScoreModelGraderInputContentUnionParam `json:"content,omitzero,required"`
+ // The role of the message input. One of `user`, `assistant`, `system`, or
+ // `developer`.
+ //
+ // Any of "user", "assistant", "system", "developer".
+ Role string `json:"role,omitzero,required"`
+ // The type of the message input. Always `message`.
+ //
+ // Any of "message".
+ Type string `json:"type,omitzero"`
+ paramObj
+}
+
+func (r ScoreModelGraderInputParam) MarshalJSON() (data []byte, err error) {
+ type shadow ScoreModelGraderInputParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ScoreModelGraderInputParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+func init() {
+ apijson.RegisterFieldValidator[ScoreModelGraderInputParam](
+ "role", "user", "assistant", "system", "developer",
+ )
+ apijson.RegisterFieldValidator[ScoreModelGraderInputParam](
+ "type", "message",
+ )
+}
+
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ScoreModelGraderInputContentUnionParam struct {
+ OfString param.Opt[string] `json:",omitzero,inline"`
+ OfInputText *responses.ResponseInputTextParam `json:",omitzero,inline"`
+ OfOutputText *ScoreModelGraderInputContentOutputTextParam `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ScoreModelGraderInputContentUnionParam) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfString, u.OfInputText, u.OfOutputText)
+}
+func (u *ScoreModelGraderInputContentUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ScoreModelGraderInputContentUnionParam) asAny() any {
+ if !param.IsOmitted(u.OfString) {
+ return &u.OfString.Value
+ } else if !param.IsOmitted(u.OfInputText) {
+ return u.OfInputText
+ } else if !param.IsOmitted(u.OfOutputText) {
+ return u.OfOutputText
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ScoreModelGraderInputContentUnionParam) GetText() *string {
+ if vt := u.OfInputText; vt != nil {
+ return (*string)(&vt.Text)
+ } else if vt := u.OfOutputText; vt != nil {
+ return (*string)(&vt.Text)
+ }
+ return nil
+}
+
+// Returns a pointer to the underlying variant's property, if present.
+func (u ScoreModelGraderInputContentUnionParam) GetType() *string {
+ if vt := u.OfInputText; vt != nil {
+ return (*string)(&vt.Type)
+ } else if vt := u.OfOutputText; vt != nil {
+ return (*string)(&vt.Type)
+ }
+ return nil
+}
+
+// A text output from the model.
+//
+// The properties Text, Type are required.
+type ScoreModelGraderInputContentOutputTextParam struct {
+ // The text output from the model.
+ Text string `json:"text,required"`
+ // The type of the output text. Always `output_text`.
+ //
+ // This field can be elided, and will marshal its zero value as "output_text".
+ Type constant.OutputText `json:"type,required"`
+ paramObj
+}
+
+func (r ScoreModelGraderInputContentOutputTextParam) MarshalJSON() (data []byte, err error) {
+ type shadow ScoreModelGraderInputContentOutputTextParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *ScoreModelGraderInputContentOutputTextParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A StringCheckGrader object that performs a string comparison between input and
+// reference using a specified operation.
+type StringCheckGrader struct {
+ // The input text. This may include template strings.
+ Input string `json:"input,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`.
+ //
+ // Any of "eq", "ne", "like", "ilike".
+ Operation StringCheckGraderOperation `json:"operation,required"`
+ // The reference text. This may include template strings.
+ Reference string `json:"reference,required"`
+ // The object type, which is always `string_check`.
+ Type constant.StringCheck `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Input respjson.Field
+ Name respjson.Field
+ Operation respjson.Field
+ Reference respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r StringCheckGrader) RawJSON() string { return r.JSON.raw }
+func (r *StringCheckGrader) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this StringCheckGrader to a StringCheckGraderParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// StringCheckGraderParam.Overrides()
+func (r StringCheckGrader) ToParam() StringCheckGraderParam {
+ return param.Override[StringCheckGraderParam](json.RawMessage(r.RawJSON()))
+}
+
+// The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`.
+type StringCheckGraderOperation string
+
+const (
+ StringCheckGraderOperationEq StringCheckGraderOperation = "eq"
+ StringCheckGraderOperationNe StringCheckGraderOperation = "ne"
+ StringCheckGraderOperationLike StringCheckGraderOperation = "like"
+ StringCheckGraderOperationIlike StringCheckGraderOperation = "ilike"
+)
+
+// A StringCheckGrader object that performs a string comparison between input and
+// reference using a specified operation.
+//
+// The properties Input, Name, Operation, Reference, Type are required.
+type StringCheckGraderParam struct {
+ // The input text. This may include template strings.
+ Input string `json:"input,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`.
+ //
+ // Any of "eq", "ne", "like", "ilike".
+ Operation StringCheckGraderOperation `json:"operation,omitzero,required"`
+ // The reference text. This may include template strings.
+ Reference string `json:"reference,required"`
+ // The object type, which is always `string_check`.
+ //
+ // This field can be elided, and will marshal its zero value as "string_check".
+ Type constant.StringCheck `json:"type,required"`
+ paramObj
+}
+
+func (r StringCheckGraderParam) MarshalJSON() (data []byte, err error) {
+ type shadow StringCheckGraderParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *StringCheckGraderParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A TextSimilarityGrader object which grades text based on similarity metrics.
+type TextSimilarityGrader struct {
+ // The evaluation metric to use. One of `fuzzy_match`, `bleu`, `gleu`, `meteor`,
+ // `rouge_1`, `rouge_2`, `rouge_3`, `rouge_4`, `rouge_5`, or `rouge_l`.
+ //
+ // Any of "fuzzy_match", "bleu", "gleu", "meteor", "rouge_1", "rouge_2", "rouge_3",
+ // "rouge_4", "rouge_5", "rouge_l".
+ EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric,required"`
+ // The text being graded.
+ Input string `json:"input,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The text being graded against.
+ Reference string `json:"reference,required"`
+ // The type of grader.
+ Type constant.TextSimilarity `json:"type,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ EvaluationMetric respjson.Field
+ Input respjson.Field
+ Name respjson.Field
+ Reference respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r TextSimilarityGrader) RawJSON() string { return r.JSON.raw }
+func (r *TextSimilarityGrader) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this TextSimilarityGrader to a TextSimilarityGraderParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// TextSimilarityGraderParam.Overrides()
+func (r TextSimilarityGrader) ToParam() TextSimilarityGraderParam {
+ return param.Override[TextSimilarityGraderParam](json.RawMessage(r.RawJSON()))
+}
+
+// The evaluation metric to use. One of `fuzzy_match`, `bleu`, `gleu`, `meteor`,
+// `rouge_1`, `rouge_2`, `rouge_3`, `rouge_4`, `rouge_5`, or `rouge_l`.
+type TextSimilarityGraderEvaluationMetric string
+
+const (
+ TextSimilarityGraderEvaluationMetricFuzzyMatch TextSimilarityGraderEvaluationMetric = "fuzzy_match"
+ TextSimilarityGraderEvaluationMetricBleu TextSimilarityGraderEvaluationMetric = "bleu"
+ TextSimilarityGraderEvaluationMetricGleu TextSimilarityGraderEvaluationMetric = "gleu"
+ TextSimilarityGraderEvaluationMetricMeteor TextSimilarityGraderEvaluationMetric = "meteor"
+ TextSimilarityGraderEvaluationMetricRouge1 TextSimilarityGraderEvaluationMetric = "rouge_1"
+ TextSimilarityGraderEvaluationMetricRouge2 TextSimilarityGraderEvaluationMetric = "rouge_2"
+ TextSimilarityGraderEvaluationMetricRouge3 TextSimilarityGraderEvaluationMetric = "rouge_3"
+ TextSimilarityGraderEvaluationMetricRouge4 TextSimilarityGraderEvaluationMetric = "rouge_4"
+ TextSimilarityGraderEvaluationMetricRouge5 TextSimilarityGraderEvaluationMetric = "rouge_5"
+ TextSimilarityGraderEvaluationMetricRougeL TextSimilarityGraderEvaluationMetric = "rouge_l"
+)
+
+// A TextSimilarityGrader object which grades text based on similarity metrics.
+//
+// The properties EvaluationMetric, Input, Name, Reference, Type are required.
+type TextSimilarityGraderParam struct {
+ // The evaluation metric to use. One of `fuzzy_match`, `bleu`, `gleu`, `meteor`,
+ // `rouge_1`, `rouge_2`, `rouge_3`, `rouge_4`, `rouge_5`, or `rouge_l`.
+ //
+ // Any of "fuzzy_match", "bleu", "gleu", "meteor", "rouge_1", "rouge_2", "rouge_3",
+ // "rouge_4", "rouge_5", "rouge_l".
+ EvaluationMetric TextSimilarityGraderEvaluationMetric `json:"evaluation_metric,omitzero,required"`
+ // The text being graded.
+ Input string `json:"input,required"`
+ // The name of the grader.
+ Name string `json:"name,required"`
+ // The text being graded against.
+ Reference string `json:"reference,required"`
+ // The type of grader.
+ //
+ // This field can be elided, and will marshal its zero value as "text_similarity".
+ Type constant.TextSimilarity `json:"type,required"`
+ paramObj
+}
+
+func (r TextSimilarityGraderParam) MarshalJSON() (data []byte, err error) {
+ type shadow TextSimilarityGraderParam
+ return param.MarshalObject(r, (*shadow)(&r))
+}
+func (r *TextSimilarityGraderParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
@@ -14,7 +14,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
)
// ImageService contains methods and other services that help with interacting with
@@ -36,7 +36,7 @@ func NewImageService(opts ...option.RequestOption) (r ImageService) {
return
}
-// Creates a variation of a given image.
+// Creates a variation of a given image. This endpoint only supports `dall-e-2`.
func (r *ImageService) NewVariation(ctx context.Context, body ImageNewVariationParams, opts ...option.RequestOption) (res *ImagesResponse, err error) {
opts = append(r.Options[:], opts...)
path := "images/variations"
@@ -44,7 +44,8 @@ func (r *ImageService) NewVariation(ctx context.Context, body ImageNewVariationP
return
}
-// Creates an edited or extended image given an original image and a prompt.
+// Creates an edited or extended image given one or more source images and a
+// prompt. This endpoint only supports `gpt-image-1` and `dall-e-2`.
func (r *ImageService) Edit(ctx context.Context, body ImageEditParams, opts ...option.RequestOption) (res *ImagesResponse, err error) {
opts = append(r.Options[:], opts...)
path := "images/edits"
@@ -53,6 +54,7 @@ func (r *ImageService) Edit(ctx context.Context, body ImageEditParams, opts ...o
}
// Creates an image given a prompt.
+// [Learn more](https://platform.openai.com/docs/guides/images).
func (r *ImageService) Generate(ctx context.Context, body ImageGenerateParams, opts ...option.RequestOption) (res *ImagesResponse, err error) {
opts = append(r.Options[:], opts...)
path := "images/generations"
@@ -60,23 +62,24 @@ func (r *ImageService) Generate(ctx context.Context, body ImageGenerateParams, o
return
}
-// Represents the url or the content of an image generated by the OpenAI API.
+// Represents the content or the URL of an image generated by the OpenAI API.
type Image struct {
- // The base64-encoded JSON of the generated image, if `response_format` is
- // `b64_json`.
+ // The base64-encoded JSON of the generated image. Default value for `gpt-image-1`,
+ // and only present if `response_format` is set to `b64_json` for `dall-e-2` and
+ // `dall-e-3`.
B64JSON string `json:"b64_json"`
- // The prompt that was used to generate the image, if there was any revision to the
- // prompt.
+ // For `dall-e-3` only, the revised prompt that was used to generate the image.
RevisedPrompt string `json:"revised_prompt"`
- // The URL of the generated image, if `response_format` is `url` (default).
+ // When using `dall-e-2` or `dall-e-3`, the URL of the generated image if
+ // `response_format` is set to `url` (default value). Unsupported for
+ // `gpt-image-1`.
URL string `json:"url"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- B64JSON resp.Field
- RevisedPrompt resp.Field
- URL resp.Field
- ExtraFields map[string]resp.Field
+ B64JSON respjson.Field
+ RevisedPrompt respjson.Field
+ URL respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -90,20 +93,48 @@ func (r *Image) UnmarshalJSON(data []byte) error {
type ImageModel = string
const (
- ImageModelDallE2 ImageModel = "dall-e-2"
- ImageModelDallE3 ImageModel = "dall-e-3"
+ ImageModelDallE2 ImageModel = "dall-e-2"
+ ImageModelDallE3 ImageModel = "dall-e-3"
+ ImageModelGPTImage1 ImageModel = "gpt-image-1"
)
+// The response from the image generation endpoint.
type ImagesResponse struct {
- Created int64 `json:"created,required"`
- Data []Image `json:"data,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // The Unix timestamp (in seconds) of when the image was created.
+ Created int64 `json:"created,required"`
+ // The background parameter used for the image generation. Either `transparent` or
+ // `opaque`.
+ //
+ // Any of "transparent", "opaque".
+ Background ImagesResponseBackground `json:"background"`
+ // The list of generated images.
+ Data []Image `json:"data"`
+ // The output format of the image generation. Either `png`, `webp`, or `jpeg`.
+ //
+ // Any of "png", "webp", "jpeg".
+ OutputFormat ImagesResponseOutputFormat `json:"output_format"`
+ // The quality of the image generated. Either `low`, `medium`, or `high`.
+ //
+ // Any of "low", "medium", "high".
+ Quality ImagesResponseQuality `json:"quality"`
+ // The size of the image generated. Either `1024x1024`, `1024x1536`, or
+ // `1536x1024`.
+ //
+ // Any of "1024x1024", "1024x1536", "1536x1024".
+ Size ImagesResponseSize `json:"size"`
+ // For `gpt-image-1` only, the token usage information for the image generation.
+ Usage ImagesResponseUsage `json:"usage"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Created resp.Field
- Data resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ Created respjson.Field
+ Background respjson.Field
+ Data respjson.Field
+ OutputFormat respjson.Field
+ Quality respjson.Field
+ Size respjson.Field
+ Usage respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -113,12 +144,96 @@ func (r *ImagesResponse) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+// The background parameter used for the image generation. Either `transparent` or
+// `opaque`.
+type ImagesResponseBackground string
+
+const (
+ ImagesResponseBackgroundTransparent ImagesResponseBackground = "transparent"
+ ImagesResponseBackgroundOpaque ImagesResponseBackground = "opaque"
+)
+
+// The output format of the image generation. Either `png`, `webp`, or `jpeg`.
+type ImagesResponseOutputFormat string
+
+const (
+ ImagesResponseOutputFormatPNG ImagesResponseOutputFormat = "png"
+ ImagesResponseOutputFormatWebP ImagesResponseOutputFormat = "webp"
+ ImagesResponseOutputFormatJPEG ImagesResponseOutputFormat = "jpeg"
+)
+
+// The quality of the image generated. Either `low`, `medium`, or `high`.
+type ImagesResponseQuality string
+
+const (
+ ImagesResponseQualityLow ImagesResponseQuality = "low"
+ ImagesResponseQualityMedium ImagesResponseQuality = "medium"
+ ImagesResponseQualityHigh ImagesResponseQuality = "high"
+)
+
+// The size of the image generated. Either `1024x1024`, `1024x1536`, or
+// `1536x1024`.
+type ImagesResponseSize string
+
+const (
+ ImagesResponseSize1024x1024 ImagesResponseSize = "1024x1024"
+ ImagesResponseSize1024x1536 ImagesResponseSize = "1024x1536"
+ ImagesResponseSize1536x1024 ImagesResponseSize = "1536x1024"
+)
+
+// For `gpt-image-1` only, the token usage information for the image generation.
+type ImagesResponseUsage struct {
+ // The number of tokens (images and text) in the input prompt.
+ InputTokens int64 `json:"input_tokens,required"`
+ // The input tokens detailed information for the image generation.
+ InputTokensDetails ImagesResponseUsageInputTokensDetails `json:"input_tokens_details,required"`
+ // The number of image tokens in the output image.
+ OutputTokens int64 `json:"output_tokens,required"`
+ // The total number of tokens (images and text) used for the image generation.
+ TotalTokens int64 `json:"total_tokens,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ InputTokens respjson.Field
+ InputTokensDetails respjson.Field
+ OutputTokens respjson.Field
+ TotalTokens respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ImagesResponseUsage) RawJSON() string { return r.JSON.raw }
+func (r *ImagesResponseUsage) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The input tokens detailed information for the image generation.
+type ImagesResponseUsageInputTokensDetails struct {
+ // The number of image tokens in the input prompt.
+ ImageTokens int64 `json:"image_tokens,required"`
+ // The number of text tokens in the input prompt.
+ TextTokens int64 `json:"text_tokens,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ImageTokens respjson.Field
+ TextTokens respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ImagesResponseUsageInputTokensDetails) RawJSON() string { return r.JSON.raw }
+func (r *ImagesResponseUsageInputTokensDetails) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
type ImageNewVariationParams struct {
// The image to use as the basis for the variation(s). Must be a valid PNG file,
// less than 4MB, and square.
- Image io.Reader `json:"image,required" format:"binary"`
- // The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
- // `n=1` is supported.
+ Image io.Reader `json:"image,omitzero,required" format:"binary"`
+ // The number of images to generate. Must be between 1 and 10.
N param.Opt[int64] `json:"n,omitzero"`
// A unique identifier representing your end-user, which can help OpenAI to monitor
// and detect abuse.
@@ -141,14 +256,13 @@ type ImageNewVariationParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageNewVariationParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r ImageNewVariationParams) MarshalMultipart() (data []byte, contentType string, err error) {
buf := bytes.NewBuffer(nil)
writer := multipart.NewWriter(buf)
err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
if err != nil {
writer.Close()
return nil, "", err
@@ -181,47 +295,81 @@ const (
)
type ImageEditParams struct {
- // The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask
- // is not provided, image must have transparency, which will be used as the mask.
- Image io.Reader `json:"image,required" format:"binary"`
+ // The image(s) to edit. Must be a supported image file or an array of images.
+ //
+ // For `gpt-image-1`, each image should be a `png`, `webp`, or `jpg` file less than
+ // 50MB. You can provide up to 16 images.
+ //
+ // For `dall-e-2`, you can only provide one image, and it should be a square `png`
+ // file less than 4MB.
+ Image ImageEditParamsImageUnion `json:"image,omitzero,required" format:"binary"`
// A text description of the desired image(s). The maximum length is 1000
- // characters.
+ // characters for `dall-e-2`, and 32000 characters for `gpt-image-1`.
Prompt string `json:"prompt,required"`
// The number of images to generate. Must be between 1 and 10.
N param.Opt[int64] `json:"n,omitzero"`
+ // The compression level (0-100%) for the generated images. This parameter is only
+ // supported for `gpt-image-1` with the `webp` or `jpeg` output formats, and
+ // defaults to 100.
+ OutputCompression param.Opt[int64] `json:"output_compression,omitzero"`
// A unique identifier representing your end-user, which can help OpenAI to monitor
// and detect abuse.
// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
User param.Opt[string] `json:"user,omitzero"`
- // The model to use for image generation. Only `dall-e-2` is supported at this
- // time.
+ // Allows to set transparency for the background of the generated image(s). This
+ // parameter is only supported for `gpt-image-1`. Must be one of `transparent`,
+ // `opaque` or `auto` (default value). When `auto` is used, the model will
+ // automatically determine the best background for the image.
+ //
+ // If `transparent`, the output format needs to support transparency, so it should
+ // be set to either `png` (default value) or `webp`.
+ //
+ // Any of "transparent", "opaque", "auto".
+ Background ImageEditParamsBackground `json:"background,omitzero"`
+ // The model to use for image generation. Only `dall-e-2` and `gpt-image-1` are
+ // supported. Defaults to `dall-e-2` unless a parameter specific to `gpt-image-1`
+ // is used.
Model ImageModel `json:"model,omitzero"`
+ // The format in which the generated images are returned. This parameter is only
+ // supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`. The
+ // default value is `png`.
+ //
+ // Any of "png", "jpeg", "webp".
+ OutputFormat ImageEditParamsOutputFormat `json:"output_format,omitzero"`
+ // The quality of the image that will be generated. `high`, `medium` and `low` are
+ // only supported for `gpt-image-1`. `dall-e-2` only supports `standard` quality.
+ // Defaults to `auto`.
+ //
+ // Any of "standard", "low", "medium", "high", "auto".
+ Quality ImageEditParamsQuality `json:"quality,omitzero"`
// The format in which the generated images are returned. Must be one of `url` or
// `b64_json`. URLs are only valid for 60 minutes after the image has been
- // generated.
+ // generated. This parameter is only supported for `dall-e-2`, as `gpt-image-1`
+ // will always return base64-encoded images.
//
// Any of "url", "b64_json".
ResponseFormat ImageEditParamsResponseFormat `json:"response_format,omitzero"`
- // The size of the generated images. Must be one of `256x256`, `512x512`, or
- // `1024x1024`.
+ // The size of the generated images. Must be one of `1024x1024`, `1536x1024`
+ // (landscape), `1024x1536` (portrait), or `auto` (default value) for
+ // `gpt-image-1`, and one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`.
//
- // Any of "256x256", "512x512", "1024x1024".
+ // Any of "256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto".
Size ImageEditParamsSize `json:"size,omitzero"`
// An additional image whose fully transparent areas (e.g. where alpha is zero)
- // indicate where `image` should be edited. Must be a valid PNG file, less than
+ // indicate where `image` should be edited. If there are multiple images provided,
+ // the mask will be applied on the first image. Must be a valid PNG file, less than
// 4MB, and have the same dimensions as `image`.
- Mask io.Reader `json:"mask" format:"binary"`
+ Mask io.Reader `json:"mask,omitzero" format:"binary"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageEditParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r ImageEditParams) MarshalMultipart() (data []byte, contentType string, err error) {
buf := bytes.NewBuffer(nil)
writer := multipart.NewWriter(buf)
err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
if err != nil {
writer.Close()
return nil, "", err
@@ -233,9 +381,74 @@ func (r ImageEditParams) MarshalMultipart() (data []byte, contentType string, er
return buf.Bytes(), writer.FormDataContentType(), nil
}
+// Only one field can be non-zero.
+//
+// Use [param.IsOmitted] to confirm if a field is set.
+type ImageEditParamsImageUnion struct {
+ OfFile io.Reader `json:",omitzero,inline"`
+ OfFileArray []io.Reader `json:",omitzero,inline"`
+ paramUnion
+}
+
+func (u ImageEditParamsImageUnion) MarshalJSON() ([]byte, error) {
+ return param.MarshalUnion(u, u.OfFile, u.OfFileArray)
+}
+func (u *ImageEditParamsImageUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
+}
+
+func (u *ImageEditParamsImageUnion) asAny() any {
+ if !param.IsOmitted(u.OfFile) {
+ return &u.OfFile
+ } else if !param.IsOmitted(u.OfFileArray) {
+ return &u.OfFileArray
+ }
+ return nil
+}
+
+// Allows to set transparency for the background of the generated image(s). This
+// parameter is only supported for `gpt-image-1`. Must be one of `transparent`,
+// `opaque` or `auto` (default value). When `auto` is used, the model will
+// automatically determine the best background for the image.
+//
+// If `transparent`, the output format needs to support transparency, so it should
+// be set to either `png` (default value) or `webp`.
+type ImageEditParamsBackground string
+
+const (
+ ImageEditParamsBackgroundTransparent ImageEditParamsBackground = "transparent"
+ ImageEditParamsBackgroundOpaque ImageEditParamsBackground = "opaque"
+ ImageEditParamsBackgroundAuto ImageEditParamsBackground = "auto"
+)
+
+// The format in which the generated images are returned. This parameter is only
+// supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`. The
+// default value is `png`.
+type ImageEditParamsOutputFormat string
+
+const (
+ ImageEditParamsOutputFormatPNG ImageEditParamsOutputFormat = "png"
+ ImageEditParamsOutputFormatJPEG ImageEditParamsOutputFormat = "jpeg"
+ ImageEditParamsOutputFormatWebP ImageEditParamsOutputFormat = "webp"
+)
+
+// The quality of the image that will be generated. `high`, `medium` and `low` are
+// only supported for `gpt-image-1`. `dall-e-2` only supports `standard` quality.
+// Defaults to `auto`.
+type ImageEditParamsQuality string
+
+const (
+ ImageEditParamsQualityStandard ImageEditParamsQuality = "standard"
+ ImageEditParamsQualityLow ImageEditParamsQuality = "low"
+ ImageEditParamsQualityMedium ImageEditParamsQuality = "medium"
+ ImageEditParamsQualityHigh ImageEditParamsQuality = "high"
+ ImageEditParamsQualityAuto ImageEditParamsQuality = "auto"
+)
+
// The format in which the generated images are returned. Must be one of `url` or
// `b64_json`. URLs are only valid for 60 minutes after the image has been
-// generated.
+// generated. This parameter is only supported for `dall-e-2`, as `gpt-image-1`
+// will always return base64-encoded images.
type ImageEditParamsResponseFormat string
const (
@@ -243,79 +456,159 @@ const (
ImageEditParamsResponseFormatB64JSON ImageEditParamsResponseFormat = "b64_json"
)
-// The size of the generated images. Must be one of `256x256`, `512x512`, or
-// `1024x1024`.
+// The size of the generated images. Must be one of `1024x1024`, `1536x1024`
+// (landscape), `1024x1536` (portrait), or `auto` (default value) for
+// `gpt-image-1`, and one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`.
type ImageEditParamsSize string
const (
ImageEditParamsSize256x256 ImageEditParamsSize = "256x256"
ImageEditParamsSize512x512 ImageEditParamsSize = "512x512"
ImageEditParamsSize1024x1024 ImageEditParamsSize = "1024x1024"
+ ImageEditParamsSize1536x1024 ImageEditParamsSize = "1536x1024"
+ ImageEditParamsSize1024x1536 ImageEditParamsSize = "1024x1536"
+ ImageEditParamsSizeAuto ImageEditParamsSize = "auto"
)
type ImageGenerateParams struct {
- // A text description of the desired image(s). The maximum length is 1000
- // characters for `dall-e-2` and 4000 characters for `dall-e-3`.
+ // A text description of the desired image(s). The maximum length is 32000
+ // characters for `gpt-image-1`, 1000 characters for `dall-e-2` and 4000 characters
+ // for `dall-e-3`.
Prompt string `json:"prompt,required"`
// The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only
// `n=1` is supported.
N param.Opt[int64] `json:"n,omitzero"`
+ // The compression level (0-100%) for the generated images. This parameter is only
+ // supported for `gpt-image-1` with the `webp` or `jpeg` output formats, and
+ // defaults to 100.
+ OutputCompression param.Opt[int64] `json:"output_compression,omitzero"`
// A unique identifier representing your end-user, which can help OpenAI to monitor
// and detect abuse.
// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
User param.Opt[string] `json:"user,omitzero"`
- // The model to use for image generation.
+ // Allows to set transparency for the background of the generated image(s). This
+ // parameter is only supported for `gpt-image-1`. Must be one of `transparent`,
+ // `opaque` or `auto` (default value). When `auto` is used, the model will
+ // automatically determine the best background for the image.
+ //
+ // If `transparent`, the output format needs to support transparency, so it should
+ // be set to either `png` (default value) or `webp`.
+ //
+ // Any of "transparent", "opaque", "auto".
+ Background ImageGenerateParamsBackground `json:"background,omitzero"`
+ // The model to use for image generation. One of `dall-e-2`, `dall-e-3`, or
+ // `gpt-image-1`. Defaults to `dall-e-2` unless a parameter specific to
+ // `gpt-image-1` is used.
Model ImageModel `json:"model,omitzero"`
- // The format in which the generated images are returned. Must be one of `url` or
- // `b64_json`. URLs are only valid for 60 minutes after the image has been
- // generated.
+ // Control the content-moderation level for images generated by `gpt-image-1`. Must
+ // be either `low` for less restrictive filtering or `auto` (default value).
+ //
+ // Any of "low", "auto".
+ Moderation ImageGenerateParamsModeration `json:"moderation,omitzero"`
+ // The format in which the generated images are returned. This parameter is only
+ // supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`.
+ //
+ // Any of "png", "jpeg", "webp".
+ OutputFormat ImageGenerateParamsOutputFormat `json:"output_format,omitzero"`
+ // The quality of the image that will be generated.
+ //
+ // - `auto` (default value) will automatically select the best quality for the
+ // given model.
+ // - `high`, `medium` and `low` are supported for `gpt-image-1`.
+ // - `hd` and `standard` are supported for `dall-e-3`.
+ // - `standard` is the only option for `dall-e-2`.
+ //
+ // Any of "standard", "hd", "low", "medium", "high", "auto".
+ Quality ImageGenerateParamsQuality `json:"quality,omitzero"`
+ // The format in which generated images with `dall-e-2` and `dall-e-3` are
+ // returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes
+ // after the image has been generated. This parameter isn't supported for
+ // `gpt-image-1` which will always return base64-encoded images.
//
// Any of "url", "b64_json".
ResponseFormat ImageGenerateParamsResponseFormat `json:"response_format,omitzero"`
- // The size of the generated images. Must be one of `256x256`, `512x512`, or
- // `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
- // `1024x1792` for `dall-e-3` models.
+ // The size of the generated images. Must be one of `1024x1024`, `1536x1024`
+ // (landscape), `1024x1536` (portrait), or `auto` (default value) for
+ // `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and
+ // one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3`.
//
- // Any of "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792".
+ // Any of "auto", "1024x1024", "1536x1024", "1024x1536", "256x256", "512x512",
+ // "1792x1024", "1024x1792".
Size ImageGenerateParamsSize `json:"size,omitzero"`
- // The style of the generated images. Must be one of `vivid` or `natural`. Vivid
- // causes the model to lean towards generating hyper-real and dramatic images.
- // Natural causes the model to produce more natural, less hyper-real looking
- // images. This param is only supported for `dall-e-3`.
+ // The style of the generated images. This parameter is only supported for
+ // `dall-e-3`. Must be one of `vivid` or `natural`. Vivid causes the model to lean
+ // towards generating hyper-real and dramatic images. Natural causes the model to
+ // produce more natural, less hyper-real looking images.
//
// Any of "vivid", "natural".
Style ImageGenerateParamsStyle `json:"style,omitzero"`
- // The quality of the image that will be generated. `hd` creates images with finer
- // details and greater consistency across the image. This param is only supported
- // for `dall-e-3`.
- //
- // Any of "standard", "hd".
- Quality ImageGenerateParamsQuality `json:"quality,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ImageGenerateParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r ImageGenerateParams) MarshalJSON() (data []byte, err error) {
type shadow ImageGenerateParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ImageGenerateParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Allows to set transparency for the background of the generated image(s). This
+// parameter is only supported for `gpt-image-1`. Must be one of `transparent`,
+// `opaque` or `auto` (default value). When `auto` is used, the model will
+// automatically determine the best background for the image.
+//
+// If `transparent`, the output format needs to support transparency, so it should
+// be set to either `png` (default value) or `webp`.
+type ImageGenerateParamsBackground string
+
+const (
+ ImageGenerateParamsBackgroundTransparent ImageGenerateParamsBackground = "transparent"
+ ImageGenerateParamsBackgroundOpaque ImageGenerateParamsBackground = "opaque"
+ ImageGenerateParamsBackgroundAuto ImageGenerateParamsBackground = "auto"
+)
-// The quality of the image that will be generated. `hd` creates images with finer
-// details and greater consistency across the image. This param is only supported
-// for `dall-e-3`.
+// Control the content-moderation level for images generated by `gpt-image-1`. Must
+// be either `low` for less restrictive filtering or `auto` (default value).
+type ImageGenerateParamsModeration string
+
+const (
+ ImageGenerateParamsModerationLow ImageGenerateParamsModeration = "low"
+ ImageGenerateParamsModerationAuto ImageGenerateParamsModeration = "auto"
+)
+
+// The format in which the generated images are returned. This parameter is only
+// supported for `gpt-image-1`. Must be one of `png`, `jpeg`, or `webp`.
+type ImageGenerateParamsOutputFormat string
+
+const (
+ ImageGenerateParamsOutputFormatPNG ImageGenerateParamsOutputFormat = "png"
+ ImageGenerateParamsOutputFormatJPEG ImageGenerateParamsOutputFormat = "jpeg"
+ ImageGenerateParamsOutputFormatWebP ImageGenerateParamsOutputFormat = "webp"
+)
+
+// The quality of the image that will be generated.
+//
+// - `auto` (default value) will automatically select the best quality for the
+// given model.
+// - `high`, `medium` and `low` are supported for `gpt-image-1`.
+// - `hd` and `standard` are supported for `dall-e-3`.
+// - `standard` is the only option for `dall-e-2`.
type ImageGenerateParamsQuality string
const (
ImageGenerateParamsQualityStandard ImageGenerateParamsQuality = "standard"
ImageGenerateParamsQualityHD ImageGenerateParamsQuality = "hd"
+ ImageGenerateParamsQualityLow ImageGenerateParamsQuality = "low"
+ ImageGenerateParamsQualityMedium ImageGenerateParamsQuality = "medium"
+ ImageGenerateParamsQualityHigh ImageGenerateParamsQuality = "high"
+ ImageGenerateParamsQualityAuto ImageGenerateParamsQuality = "auto"
)
-// The format in which the generated images are returned. Must be one of `url` or
-// `b64_json`. URLs are only valid for 60 minutes after the image has been
-// generated.
+// The format in which generated images with `dall-e-2` and `dall-e-3` are
+// returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes
+// after the image has been generated. This parameter isn't supported for
+// `gpt-image-1` which will always return base64-encoded images.
type ImageGenerateParamsResponseFormat string
const (
@@ -323,23 +616,27 @@ const (
ImageGenerateParamsResponseFormatB64JSON ImageGenerateParamsResponseFormat = "b64_json"
)
-// The size of the generated images. Must be one of `256x256`, `512x512`, or
-// `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or
-// `1024x1792` for `dall-e-3` models.
+// The size of the generated images. Must be one of `1024x1024`, `1536x1024`
+// (landscape), `1024x1536` (portrait), or `auto` (default value) for
+// `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and
+// one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3`.
type ImageGenerateParamsSize string
const (
+ ImageGenerateParamsSizeAuto ImageGenerateParamsSize = "auto"
+ ImageGenerateParamsSize1024x1024 ImageGenerateParamsSize = "1024x1024"
+ ImageGenerateParamsSize1536x1024 ImageGenerateParamsSize = "1536x1024"
+ ImageGenerateParamsSize1024x1536 ImageGenerateParamsSize = "1024x1536"
ImageGenerateParamsSize256x256 ImageGenerateParamsSize = "256x256"
ImageGenerateParamsSize512x512 ImageGenerateParamsSize = "512x512"
- ImageGenerateParamsSize1024x1024 ImageGenerateParamsSize = "1024x1024"
ImageGenerateParamsSize1792x1024 ImageGenerateParamsSize = "1792x1024"
ImageGenerateParamsSize1024x1792 ImageGenerateParamsSize = "1024x1792"
)
-// The style of the generated images. Must be one of `vivid` or `natural`. Vivid
-// causes the model to lean towards generating hyper-real and dramatic images.
-// Natural causes the model to produce more natural, less hyper-real looking
-// images. This param is only supported for `dall-e-3`.
+// The style of the generated images. This parameter is only supported for
+// `dall-e-3`. Must be one of `vivid` or `natural`. Vivid causes the model to lean
+// towards generating hyper-real and dramatic images. Natural causes the model to
+// produce more natural, less hyper-real looking images.
type ImageGenerateParamsStyle string
const (
@@ -8,7 +8,7 @@ import (
"net/http/httputil"
"github.com/openai/openai-go/internal/apijson"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
)
// Error represents an error that originates from the API, i.e. when a request is
@@ -19,14 +19,13 @@ type Error struct {
Message string `json:"message,required"`
Param string `json:"param,required"`
Type string `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Message resp.Field
- Param resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Message respjson.Field
+ Param respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
StatusCode int
@@ -13,23 +13,38 @@ import (
"sync"
"time"
- internalparam "github.com/openai/openai-go/internal/param"
"github.com/openai/openai-go/packages/param"
)
var encoders sync.Map // map[encoderEntry]encoderFunc
-func Marshal(value interface{}, writer *multipart.Writer) error {
- e := &encoder{dateFormat: time.RFC3339}
+func Marshal(value any, writer *multipart.Writer) error {
+ e := &encoder{
+ dateFormat: time.RFC3339,
+ arrayFmt: "brackets",
+ }
return e.marshal(value, writer)
}
-func MarshalRoot(value interface{}, writer *multipart.Writer) error {
- e := &encoder{root: true, dateFormat: time.RFC3339}
+func MarshalRoot(value any, writer *multipart.Writer) error {
+ e := &encoder{
+ root: true,
+ dateFormat: time.RFC3339,
+ arrayFmt: "brackets",
+ }
+ return e.marshal(value, writer)
+}
+
+func MarshalWithSettings(value any, writer *multipart.Writer, arrayFormat string) error {
+ e := &encoder{
+ arrayFmt: arrayFormat,
+ dateFormat: time.RFC3339,
+ }
return e.marshal(value, writer)
}
type encoder struct {
+ arrayFmt string
dateFormat string
root bool
}
@@ -48,7 +63,7 @@ type encoderEntry struct {
root bool
}
-func (e *encoder) marshal(value interface{}, writer *multipart.Writer) error {
+func (e *encoder) marshal(value any, writer *multipart.Writer) error {
val := reflect.ValueOf(value)
if !val.IsValid() {
return nil
@@ -97,7 +112,7 @@ func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
return e.newTimeTypeEncoder()
}
- if t.ConvertibleTo(reflect.TypeOf((*io.Reader)(nil)).Elem()) {
+ if t.Implements(reflect.TypeOf((*io.Reader)(nil)).Elem()) {
return e.newReaderTypeEncoder()
}
e.root = false
@@ -163,15 +178,40 @@ func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
}
}
+func arrayKeyEncoder(arrayFmt string) func(string, int) string {
+ var keyFn func(string, int) string
+ switch arrayFmt {
+ case "comma", "repeat":
+ keyFn = func(k string, _ int) string { return k }
+ case "brackets":
+ keyFn = func(key string, _ int) string { return key + "[]" }
+ case "indices:dots":
+ keyFn = func(k string, i int) string {
+ if k == "" {
+ return strconv.Itoa(i)
+ }
+ return k + "." + strconv.Itoa(i)
+ }
+ case "indices:brackets":
+ keyFn = func(k string, i int) string {
+ if k == "" {
+ return strconv.Itoa(i)
+ }
+ return k + "[" + strconv.Itoa(i) + "]"
+ }
+ }
+ return keyFn
+}
+
func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
itemEncoder := e.typeEncoder(t.Elem())
-
+ keyFn := arrayKeyEncoder(e.arrayFmt)
return func(key string, v reflect.Value, writer *multipart.Writer) error {
- if key != "" {
- key = key + "."
+ if keyFn == nil {
+ return fmt.Errorf("apiform: unsupported array format")
}
for i := 0; i < v.Len(); i++ {
- err := itemEncoder(key+strconv.Itoa(i), v.Index(i), writer)
+ err := itemEncoder(keyFn(key, i), v.Index(i), writer)
if err != nil {
return err
}
@@ -181,12 +221,14 @@ func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
}
func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
- if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) {
- return e.newFieldTypeEncoder(t)
+ if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
+ return e.newRichFieldTypeEncoder(t)
}
- if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
- return e.newRichFieldTypeEncoder(t, idx)
+ for i := 0; i < t.NumField(); i++ {
+ if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
+ return e.newStructUnionTypeEncoder(t)
+ }
}
encoderFields := []encoderField{}
@@ -284,24 +326,29 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
}
}
-func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
- f, _ := t.FieldByName("Value")
- enc := e.typeEncoder(f.Type)
+var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
- return func(key string, value reflect.Value, writer *multipart.Writer) error {
- present := value.FieldByName("Present")
- if !present.Bool() {
- return nil
- }
- null := value.FieldByName("Null")
- if null.Bool() {
- return nil
+func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
+ var fieldEncoders []encoderFunc
+ for i := 0; i < t.NumField(); i++ {
+ field := t.Field(i)
+ if field.Type == paramUnionType && field.Anonymous {
+ fieldEncoders = append(fieldEncoders, nil)
+ continue
}
- raw := value.FieldByName("Raw")
- if !raw.IsNil() {
- return e.typeEncoder(raw.Type())(key, raw, writer)
+ fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
+ }
+
+ return func(key string, value reflect.Value, writer *multipart.Writer) error {
+ for i := 0; i < t.NumField(); i++ {
+ if value.Field(i).Type() == paramUnionType {
+ continue
+ }
+ if !value.Field(i).IsZero() {
+ return fieldEncoders[i](key, value.Field(i), writer)
+ }
}
- return enc(key, value.FieldByName("Value"), writer)
+ return fmt.Errorf("apiform: union %s has no field set", t.String())
}
}
@@ -330,7 +377,10 @@ func escapeQuotes(s string) string {
func (e *encoder) newReaderTypeEncoder() encoderFunc {
return func(key string, value reflect.Value, writer *multipart.Writer) error {
- reader := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader)
+ reader, ok := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader)
+ if !ok {
+ return nil
+ }
filename := "anonymous_file"
contentType := "application/octet-stream"
if named, ok := reader.(interface{ Filename() string }); ok {
@@ -394,8 +444,22 @@ func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipar
return nil
}
-func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
+func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
return func(key string, value reflect.Value, writer *multipart.Writer) error {
return e.encodeMapEntries(key, value, writer)
}
}
+
+func WriteExtras(writer *multipart.Writer, extras map[string]any) (err error) {
+ for k, v := range extras {
+ str, ok := v.(string)
+ if !ok {
+ break
+ }
+ err = writer.WriteField(k, str)
+ if err != nil {
+ break
+ }
+ }
+ return
+}
@@ -6,13 +6,13 @@ import (
"reflect"
)
-func (e *encoder) newRichFieldTypeEncoder(t reflect.Type, underlyingValueIdx []int) encoderFunc {
- underlying := t.FieldByIndex(underlyingValueIdx)
- primitiveEncoder := e.newPrimitiveTypeEncoder(underlying.Type)
+func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc {
+ f, _ := t.FieldByName("Value")
+ enc := e.newPrimitiveTypeEncoder(f.Type)
return func(key string, value reflect.Value, writer *multipart.Writer) error {
- if opt, ok := value.Interface().(param.Optional); ok && opt.IsPresent() {
- return primitiveEncoder(key, value.FieldByIndex(underlyingValueIdx), writer)
- } else if ok && opt.IsNull() {
+ if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() {
+ return enc(key, value.FieldByIndex(f.Index), writer)
+ } else if ok && param.IsNull(opt) {
return writer.WriteField(key, "null")
}
return nil
@@ -1,9 +1,13 @@
+// The deserialization algorithm from apijson may be subject to improvements
+// between minor versions, particularly with respect to calling [json.Unmarshal]
+// into param unions.
+
package apijson
import (
"encoding/json"
- "errors"
"fmt"
+ "github.com/openai/openai-go/packages/param"
"reflect"
"strconv"
"sync"
@@ -46,6 +50,7 @@ type decoderBuilder struct {
type decoderState struct {
strict bool
exactness exactness
+ validator *validationEntry
}
// Exactness refers to how close to the type the result was if deserialization
@@ -89,6 +94,18 @@ func (d *decoderBuilder) unmarshal(raw []byte, to any) error {
return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact})
}
+// unmarshalWithExactness is used for internal testing purposes.
+func (d *decoderBuilder) unmarshalWithExactness(raw []byte, to any) (exactness, error) {
+ value := reflect.ValueOf(to).Elem()
+ result := gjson.ParseBytes(raw)
+ if !value.IsValid() {
+ return 0, fmt.Errorf("apijson: cannot marshal into invalid value")
+ }
+ state := decoderState{strict: false, exactness: exact}
+ err := d.typeDecoder(value.Type())(result, value, &state)
+ return state.exactness, err
+}
+
func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc {
entry := decoderEntry{
Type: t,
@@ -124,6 +141,24 @@ func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc {
return f
}
+// validatedTypeDecoder wraps the type decoder with a validator. This is helpful
+// for ensuring that enum fields are correct.
+func (d *decoderBuilder) validatedTypeDecoder(t reflect.Type, entry *validationEntry) decoderFunc {
+ dec := d.typeDecoder(t)
+ if entry == nil {
+ return dec
+ }
+
+ // Thread the current validation entry through the decoder,
+ // but clean up in time for the next field.
+ return func(node gjson.Result, v reflect.Value, state *decoderState) error {
+ state.validator = entry
+ err := dec(node, v, state)
+ state.validator = nil
+ return err
+ }
+}
+
func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error {
return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
}
@@ -139,6 +174,11 @@ func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
if t.ConvertibleTo(reflect.TypeOf(time.Time{})) {
return d.newTimeTypeDecoder(t)
}
+
+ if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
+ return d.newOptTypeDecoder(t)
+ }
+
if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
return unmarshalerDecoder
}
@@ -150,6 +190,9 @@ func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
d.root = false
if _, ok := unionRegistry[t]; ok {
+ if isStructUnion(t) {
+ return d.newStructUnionDecoder(t)
+ }
return d.newUnionDecoder(t)
}
@@ -173,8 +216,8 @@ func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
return nil
}
case reflect.Struct:
- if isEmbeddedUnion(t) {
- return d.newEmbeddedUnionDecoder(t)
+ if isStructUnion(t) {
+ return d.newStructUnionDecoder(t)
}
return d.newStructTypeDecoder(t)
case reflect.Array:
@@ -198,80 +241,6 @@ func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
}
}
-// newUnionDecoder returns a decoderFunc that deserializes into a union using an
-// algorithm roughly similar to Pydantic's [smart algorithm].
-//
-// Conceptually this is equivalent to choosing the best schema based on how 'exact'
-// the deserialization is for each of the schemas.
-//
-// If there is a tie in the level of exactness, then the tie is broken
-// left-to-right.
-//
-// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode
-func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc {
- unionEntry, ok := unionRegistry[t]
- if !ok {
- panic("apijson: couldn't find union of type " + t.String() + " in union registry")
- }
- decoders := []decoderFunc{}
- for _, variant := range unionEntry.variants {
- decoder := d.typeDecoder(variant.Type)
- decoders = append(decoders, decoder)
- }
- return func(n gjson.Result, v reflect.Value, state *decoderState) error {
- // If there is a discriminator match, circumvent the exactness logic entirely
- for idx, variant := range unionEntry.variants {
- decoder := decoders[idx]
- if variant.TypeFilter != n.Type {
- continue
- }
-
- if len(unionEntry.discriminatorKey) != 0 {
- discriminatorValue := n.Get(unionEntry.discriminatorKey).Value()
- if discriminatorValue == variant.DiscriminatorValue {
- inner := reflect.New(variant.Type).Elem()
- err := decoder(n, inner, state)
- v.Set(inner)
- return err
- }
- }
- }
-
- // Set bestExactness to worse than loose
- bestExactness := loose - 1
- for idx, variant := range unionEntry.variants {
- decoder := decoders[idx]
- if variant.TypeFilter != n.Type {
- continue
- }
- sub := decoderState{strict: state.strict, exactness: exact}
- inner := reflect.New(variant.Type).Elem()
- err := decoder(n, inner, &sub)
- if err != nil {
- continue
- }
- if sub.exactness == exact {
- v.Set(inner)
- return nil
- }
- if sub.exactness > bestExactness {
- v.Set(inner)
- bestExactness = sub.exactness
- }
- }
-
- if bestExactness < loose {
- return errors.New("apijson: was not able to coerce type as union")
- }
-
- if guardStrict(state, bestExactness != exact) {
- return errors.New("apijson: was not able to coerce type as union strictly")
- }
-
- return nil
- }
-}
-
func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc {
keyType := t.Key()
itemType := t.Elem()
@@ -348,12 +317,23 @@ func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
extraDecoder := (*decoderField)(nil)
var inlineDecoders []decoderField
+ validationEntries := validationRegistry[t]
+
for i := 0; i < t.NumField(); i++ {
idx := []int{i}
field := t.FieldByIndex(idx)
if !field.IsExported() {
continue
}
+
+ var validator *validationEntry
+ for _, entry := range validationEntries {
+ if entry.field.Offset == field.Offset {
+ validator = &entry
+ break
+ }
+ }
+
// If this is an embedded struct, traverse one level deeper to extract
// the fields and get their encoders as well.
if field.Anonymous {
@@ -394,7 +374,13 @@ func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
d.dateFormat = "2006-01-02"
}
}
- decoderFields[ptag.name] = decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name}
+
+ decoderFields[ptag.name] = decoderField{
+ ptag,
+ d.validatedTypeDecoder(field.Type, validator),
+ idx, field.Name,
+ }
+
d.dateFormat = oldFormat
}
@@ -474,6 +460,12 @@ func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc {
}
}
+ // Handle null [param.Opt]
+ if itemNode.Type == gjson.Null && dest.IsValid() && dest.Type().Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
+ dest.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(itemNode.Raw))
+ continue
+ }
+
if itemNode.Type == gjson.Null {
meta = Field{
raw: itemNode.Raw,
@@ -530,6 +522,9 @@ func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
if n.Type == gjson.JSON {
return fmt.Errorf("apijson: failed to parse string")
}
+
+ state.validateString(v)
+
if guardUnknown(state, v) {
return fmt.Errorf("apijson: failed string enum validation")
}
@@ -546,6 +541,9 @@ func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON {
return fmt.Errorf("apijson: failed to parse bool")
}
+
+ state.validateBool(v)
+
if guardUnknown(state, v) {
return fmt.Errorf("apijson: failed bool enum validation")
}
@@ -562,6 +560,9 @@ func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) {
return fmt.Errorf("apijson: failed to parse int")
}
+
+ state.validateInt(v)
+
if guardUnknown(state, v) {
return fmt.Errorf("apijson: failed int enum validation")
}
@@ -606,6 +607,17 @@ func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc {
}
}
+func (d *decoderBuilder) newOptTypeDecoder(t reflect.Type) decoderFunc {
+ for t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ }
+ valueField, _ := t.FieldByName("Value")
+ return func(n gjson.Result, v reflect.Value, state *decoderState) error {
+ state.validateOptKind(n, valueField.Type)
+ return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw))
+ }
+}
+
func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc {
format := d.dateFormat
return func(n gjson.Result, v reflect.Value, state *decoderState) error {
@@ -641,7 +653,7 @@ func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc {
}
}
-func setUnexportedField(field reflect.Value, value interface{}) {
+func setUnexportedField(field reflect.Value, value any) {
reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value))
}
@@ -12,18 +12,16 @@ import (
"time"
"github.com/tidwall/sjson"
-
- "github.com/openai/openai-go/internal/param"
)
var encoders sync.Map // map[encoderEntry]encoderFunc
-func Marshal(value interface{}) ([]byte, error) {
+func Marshal(value any) ([]byte, error) {
e := &encoder{dateFormat: time.RFC3339}
return e.marshal(value)
}
-func MarshalRoot(value interface{}) ([]byte, error) {
+func MarshalRoot(value any) ([]byte, error) {
e := &encoder{root: true, dateFormat: time.RFC3339}
return e.marshal(value)
}
@@ -47,7 +45,7 @@ type encoderEntry struct {
root bool
}
-func (e *encoder) marshal(value interface{}) ([]byte, error) {
+func (e *encoder) marshal(value any) ([]byte, error) {
val := reflect.ValueOf(value)
if !val.IsValid() {
return nil, nil
@@ -202,10 +200,6 @@ func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
}
func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
- if t.Implements(reflect.TypeOf((*param.FieldLike)(nil)).Elem()) {
- return e.newFieldTypeEncoder(t)
- }
-
encoderFields := []encoderField{}
extraEncoder := (*encoderField)(nil)
@@ -3,7 +3,10 @@ package apijson
import (
"fmt"
"reflect"
+ "slices"
"sync"
+
+ "github.com/tidwall/gjson"
)
/********************/
@@ -12,8 +15,13 @@ import (
type validationEntry struct {
field reflect.StructField
- nullable bool
- legalValues []reflect.Value
+ required bool
+ legalValues struct {
+ strings []string
+ // 1 represents true, 0 represents false, -1 represents either
+ bools int
+ ints []int64
+ }
}
type validatorFunc func(reflect.Value) exactness
@@ -21,7 +29,7 @@ type validatorFunc func(reflect.Value) exactness
var validators sync.Map
var validationRegistry = map[reflect.Type][]validationEntry{}
-func RegisterFieldValidator[T any, V string | bool | int](fieldName string, nullable bool, values ...V) {
+func RegisterFieldValidator[T any, V string | bool | int](fieldName string, values ...V) {
var t T
parentType := reflect.TypeOf(t)
@@ -34,14 +42,45 @@ func RegisterFieldValidator[T any, V string | bool | int](fieldName string, null
if parentType.Kind() != reflect.Struct {
panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String()))
}
- field, found := parentType.FieldByName(fieldName)
+
+ var field reflect.StructField
+ found := false
+ for i := 0; i < parentType.NumField(); i++ {
+ ptag, ok := parseJSONStructTag(parentType.Field(i))
+ if ok && ptag.name == fieldName {
+ field = parentType.Field(i)
+ found = true
+ break
+ }
+ }
+
if !found {
- panic(fmt.Sprintf("apijson: cannot initialize validator for unknown field %q in %s", fieldName, parentType.String()))
+ panic(fmt.Sprintf("apijson: cannot find field %s in struct %s", fieldName, parentType.String()))
}
- newEntry := validationEntry{field, nullable, make([]reflect.Value, len(values))}
- for i, value := range values {
- newEntry.legalValues[i] = reflect.ValueOf(value)
+ newEntry := validationEntry{field: field}
+ newEntry.legalValues.bools = -1 // default to either
+
+ switch values := any(values).(type) {
+ case []string:
+ newEntry.legalValues.strings = values
+ case []int:
+ newEntry.legalValues.ints = make([]int64, len(values))
+ for i, value := range values {
+ newEntry.legalValues.ints[i] = int64(value)
+ }
+ case []bool:
+ for i, value := range values {
+ var next int
+ if value {
+ next = 1
+ }
+ if i > 0 && newEntry.legalValues.bools != next {
+ newEntry.legalValues.bools = -1 // accept either
+ break
+ }
+ newEntry.legalValues.bools = next
+ }
}
// Store the information necessary to create a validator, so that we can use it
@@ -49,39 +88,58 @@ func RegisterFieldValidator[T any, V string | bool | int](fieldName string, null
validationRegistry[parentType] = append(validationRegistry[parentType], newEntry)
}
-// Enums are the only types which are validated
-func typeValidator(t reflect.Type) validatorFunc {
- entry, ok := validationRegistry[t]
- if !ok {
- return nil
+func (state *decoderState) validateString(v reflect.Value) {
+ if state.validator == nil {
+ return
}
-
- if fi, ok := validators.Load(t); ok {
- return fi.(validatorFunc)
+ if !slices.Contains(state.validator.legalValues.strings, v.String()) {
+ state.exactness = loose
}
+}
- fi, _ := validators.LoadOrStore(t, validatorFunc(func(v reflect.Value) exactness {
- return validateEnum(v, entry)
- }))
- return fi.(validatorFunc)
+func (state *decoderState) validateInt(v reflect.Value) {
+ if state.validator == nil {
+ return
+ }
+ if !slices.Contains(state.validator.legalValues.ints, v.Int()) {
+ state.exactness = loose
+ }
}
-func validateEnum(v reflect.Value, entry []validationEntry) exactness {
- if v.Kind() != reflect.Struct {
- return loose
+func (state *decoderState) validateBool(v reflect.Value) {
+ if state.validator == nil {
+ return
}
+ b := v.Bool()
+ if state.validator.legalValues.bools == 1 && b == false {
+ state.exactness = loose
+ } else if state.validator.legalValues.bools == 0 && b == true {
+ state.exactness = loose
+ }
+}
- for _, check := range entry {
- field := v.FieldByIndex(check.field.Index)
- if !field.IsValid() {
- return loose
+func (state *decoderState) validateOptKind(node gjson.Result, t reflect.Type) {
+ switch node.Type {
+ case gjson.JSON:
+ state.exactness = loose
+ case gjson.Null:
+ return
+ case gjson.False, gjson.True:
+ if t.Kind() != reflect.Bool {
+ state.exactness = loose
}
- for _, opt := range check.legalValues {
- if field.Equal(opt) {
- return exact
- }
+ case gjson.Number:
+ switch t.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+ reflect.Float32, reflect.Float64:
+ return
+ default:
+ state.exactness = loose
+ }
+ case gjson.String:
+ if t.Kind() != reflect.String {
+ state.exactness = loose
}
}
-
- return loose
}
@@ -8,18 +8,27 @@ import (
type UnionVariant struct {
TypeFilter gjson.Type
- DiscriminatorValue interface{}
+ DiscriminatorValue any
Type reflect.Type
}
var unionRegistry = map[reflect.Type]unionEntry{}
-var unionVariants = map[reflect.Type]interface{}{}
+var unionVariants = map[reflect.Type]any{}
type unionEntry struct {
discriminatorKey string
variants []UnionVariant
}
+func Discriminator[T any](value any) UnionVariant {
+ var zero T
+ return UnionVariant{
+ TypeFilter: gjson.JSON,
+ DiscriminatorValue: value,
+ Type: reflect.TypeOf(zero),
+ }
+}
+
func RegisterUnion[T any](discriminator string, variants ...UnionVariant) {
typ := reflect.TypeOf((*T)(nil)).Elem()
unionRegistry[typ] = unionEntry{
@@ -1,7 +1,7 @@
package apijson
import (
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"reflect"
)
@@ -45,7 +45,7 @@ func setMetadataExtraFields(root reflect.Value, index []int, name string, metaEx
return
}
- newMap := make(map[string]resp.Field, len(metaExtras))
+ newMap := make(map[string]respjson.Field, len(metaExtras))
if target.Type() == reflect.TypeOf(newMap) {
for k, v := range metaExtras {
newMap[k] = v.toRespField()
@@ -54,14 +54,14 @@ func setMetadataExtraFields(root reflect.Value, index []int, name string, metaEx
}
}
-func (f Field) toRespField() resp.Field {
- if f.IsNull() {
- return resp.NewNullField()
- } else if f.IsMissing() {
- return resp.Field{}
+func (f Field) toRespField() respjson.Field {
+ if f.IsMissing() {
+ return respjson.Field{}
+ } else if f.IsNull() {
+ return respjson.NewField("null")
} else if f.IsInvalid() {
- return resp.NewInvalidField(f.raw)
+ return respjson.NewInvalidField(f.raw)
} else {
- return resp.NewValidField(f.raw)
+ return respjson.NewField(f.raw)
}
}
@@ -8,10 +8,14 @@ import (
"github.com/tidwall/gjson"
)
-func isEmbeddedUnion(t reflect.Type) bool {
- var apiunion param.APIUnion
+var apiUnionType = reflect.TypeOf(param.APIUnion{})
+
+func isStructUnion(t reflect.Type) bool {
+ if t.Kind() != reflect.Struct {
+ return false
+ }
for i := 0; i < t.NumField(); i++ {
- if t.Field(i).Type == reflect.TypeOf(apiunion) && t.Field(i).Anonymous {
+ if t.Field(i).Type == apiUnionType && t.Field(i).Anonymous {
return true
}
}
@@ -33,19 +37,116 @@ func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.T
unionRegistry[reflect.TypeOf(t)] = entry
}
-func (d *decoderBuilder) newEmbeddedUnionDecoder(t reflect.Type) decoderFunc {
- decoders := []decoderFunc{}
+func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc {
+ type variantDecoder struct {
+ decoder decoderFunc
+ field reflect.StructField
+ discriminatorValue any
+ }
+ variants := []variantDecoder{}
for i := 0; i < t.NumField(); i++ {
- variant := t.Field(i)
- decoder := d.typeDecoder(variant.Type)
- decoders = append(decoders, decoder)
+ field := t.Field(i)
+
+ if field.Anonymous && field.Type == apiUnionType {
+ continue
+ }
+
+ decoder := d.typeDecoder(field.Type)
+ variants = append(variants, variantDecoder{
+ decoder: decoder,
+ field: field,
+ })
+ }
+
+ unionEntry, discriminated := unionRegistry[t]
+ for _, unionVariant := range unionEntry.variants {
+ for i := 0; i < len(variants); i++ {
+ variant := &variants[i]
+ if variant.field.Type.Elem() == unionVariant.Type {
+ variant.discriminatorValue = unionVariant.DiscriminatorValue
+ break
+ }
+ }
}
- unionEntry := unionEntry{
- variants: []UnionVariant{},
+ return func(n gjson.Result, v reflect.Value, state *decoderState) error {
+ if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 {
+ discriminator := n.Get(unionEntry.discriminatorKey).Value()
+ for _, variant := range variants {
+ if discriminator == variant.discriminatorValue {
+ inner := v.FieldByIndex(variant.field.Index)
+ return variant.decoder(n, inner, state)
+ }
+ }
+ return errors.New("apijson: was not able to find discriminated union variant")
+ }
+
+ // Set bestExactness to worse than loose
+ bestExactness := loose - 1
+ bestVariant := -1
+ for i, variant := range variants {
+ // Pointers are used to discern JSON object variants from value variants
+ if n.Type != gjson.JSON && variant.field.Type.Kind() == reflect.Ptr {
+ continue
+ }
+
+ sub := decoderState{strict: state.strict, exactness: exact}
+ inner := v.FieldByIndex(variant.field.Index)
+ err := variant.decoder(n, inner, &sub)
+ if err != nil {
+ continue
+ }
+ if sub.exactness == exact {
+ bestExactness = exact
+ bestVariant = i
+ break
+ }
+ if sub.exactness > bestExactness {
+ bestExactness = sub.exactness
+ bestVariant = i
+ }
+ }
+
+ if bestExactness < loose {
+ return errors.New("apijson: was not able to coerce type as union")
+ }
+
+ if guardStrict(state, bestExactness != exact) {
+ return errors.New("apijson: was not able to coerce type as union strictly")
+ }
+
+ for i := 0; i < len(variants); i++ {
+ if i == bestVariant {
+ continue
+ }
+ v.FieldByIndex(variants[i].field.Index).SetZero()
+ }
+
+ return nil
}
+}
+// newUnionDecoder returns a decoderFunc that deserializes into a union using an
+// algorithm roughly similar to Pydantic's [smart algorithm].
+//
+// Conceptually this is equivalent to choosing the best schema based on how 'exact'
+// the deserialization is for each of the schemas.
+//
+// If there is a tie in the level of exactness, then the tie is broken
+// left-to-right.
+//
+// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode
+func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc {
+ unionEntry, ok := unionRegistry[t]
+ if !ok {
+ panic("apijson: couldn't find union of type " + t.String() + " in union registry")
+ }
+ decoders := []decoderFunc{}
+ for _, variant := range unionEntry.variants {
+ decoder := d.typeDecoder(variant.Type)
+ decoders = append(decoders, decoder)
+ }
return func(n gjson.Result, v reflect.Value, state *decoderState) error {
// If there is a discriminator match, circumvent the exactness logic entirely
for idx, variant := range unionEntry.variants {
@@ -9,7 +9,6 @@ import (
"sync"
"time"
- internalparam "github.com/openai/openai-go/internal/param"
"github.com/openai/openai-go/packages/param"
)
@@ -21,7 +20,7 @@ type encoder struct {
settings QuerySettings
}
-type encoderFunc func(key string, value reflect.Value) []Pair
+type encoderFunc func(key string, value reflect.Value) ([]Pair, error)
type encoderField struct {
tag parsedStructTag
@@ -62,7 +61,7 @@ func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
f encoderFunc
)
wg.Add(1)
- fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) []Pair {
+ fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) {
wg.Wait()
return f(key, v)
}))
@@ -77,9 +76,12 @@ func (e *encoder) typeEncoder(t reflect.Type) encoderFunc {
return f
}
-func marshalerEncoder(key string, value reflect.Value) []Pair {
- s, _ := value.Interface().(json.Marshaler).MarshalJSON()
- return []Pair{{key, string(s)}}
+func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) {
+ s, err := value.Interface().(json.Marshaler).MarshalJSON()
+ if err != nil {
+ return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err)
+ }
+ return []Pair{{key, string(s)}}, nil
}
func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
@@ -87,19 +89,23 @@ func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
return e.newTimeTypeEncoder(t)
}
- if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) && param.OptionalPrimitiveTypes[t] == nil {
+ if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
+ return e.newRichFieldTypeEncoder(t)
+ }
+
+ if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) {
return marshalerEncoder
}
+
e.root = false
switch t.Kind() {
case reflect.Pointer:
encoder := e.typeEncoder(t.Elem())
- return func(key string, value reflect.Value) (pairs []Pair) {
+ return func(key string, value reflect.Value) (pairs []Pair, err error) {
if !value.IsValid() || value.IsNil() {
return
}
- pairs = encoder(key, value.Elem())
- return
+ return encoder(key, value.Elem())
}
case reflect.Struct:
return e.newStructTypeEncoder(t)
@@ -117,12 +123,14 @@ func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc {
}
func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
- if t.Implements(reflect.TypeOf((*internalparam.FieldLike)(nil)).Elem()) {
- return e.newFieldTypeEncoder(t)
+ if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
+ return e.newRichFieldTypeEncoder(t)
}
- if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
- return e.newRichFieldTypeEncoder(t, idx)
+ for i := 0; i < t.NumField(); i++ {
+ if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
+ return e.newStructUnionTypeEncoder(t)
+ }
}
encoderFields := []encoderField{}
@@ -168,9 +176,9 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
var encoderFn encoderFunc
if ptag.omitzero {
typeEncoderFn := e.typeEncoder(field.Type)
- encoderFn = func(key string, value reflect.Value) []Pair {
+ encoderFn = func(key string, value reflect.Value) ([]Pair, error) {
if value.IsZero() {
- return nil
+ return nil, nil
}
return typeEncoderFn(key, value)
}
@@ -183,7 +191,7 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
}
collectEncoderFields(t, []int{})
- return func(key string, value reflect.Value) (pairs []Pair) {
+ return func(key string, value reflect.Value) (pairs []Pair, err error) {
for _, ef := range encoderFields {
var subkey string = e.renderKeyPath(key, ef.tag.name)
if ef.tag.inline {
@@ -191,25 +199,62 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
}
field := value.FieldByIndex(ef.idx)
- pairs = append(pairs, ef.fn(subkey, field)...)
+ subpairs, suberr := ef.fn(subkey, field)
+ if suberr != nil {
+ err = suberr
+ }
+ pairs = append(pairs, subpairs...)
}
return
}
}
+var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
+
+func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
+ var fieldEncoders []encoderFunc
+ for i := 0; i < t.NumField(); i++ {
+ field := t.Field(i)
+ if field.Type == paramUnionType && field.Anonymous {
+ fieldEncoders = append(fieldEncoders, nil)
+ continue
+ }
+ fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
+ }
+
+ return func(key string, value reflect.Value) (pairs []Pair, err error) {
+ for i := 0; i < t.NumField(); i++ {
+ if value.Field(i).Type() == paramUnionType {
+ continue
+ }
+ if !value.Field(i).IsZero() {
+ return fieldEncoders[i](key, value.Field(i))
+ }
+ }
+ return nil, fmt.Errorf("apiquery: union %s has no field set", t.String())
+ }
+}
+
func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
keyEncoder := e.typeEncoder(t.Key())
elementEncoder := e.typeEncoder(t.Elem())
- return func(key string, value reflect.Value) (pairs []Pair) {
+ return func(key string, value reflect.Value) (pairs []Pair, err error) {
iter := value.MapRange()
for iter.Next() {
- encodedKey := keyEncoder("", iter.Key())
+ encodedKey, err := keyEncoder("", iter.Key())
+ if err != nil {
+ return nil, err
+ }
if len(encodedKey) != 1 {
- panic("Unexpected number of parts for encoded map key. Are you using a non-primitive for this map?")
+ return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive")
}
subkey := encodedKey[0].value
keyPath := e.renderKeyPath(key, subkey)
- pairs = append(pairs, elementEncoder(keyPath, iter.Value())...)
+ subpairs, suberr := elementEncoder(keyPath, iter.Value())
+ if suberr != nil {
+ err = suberr
+ }
+ pairs = append(pairs, subpairs...)
}
return
}
@@ -229,36 +274,48 @@ func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc {
switch e.settings.ArrayFormat {
case ArrayQueryFormatComma:
innerEncoder := e.typeEncoder(t.Elem())
- return func(key string, v reflect.Value) []Pair {
+ return func(key string, v reflect.Value) ([]Pair, error) {
elements := []string{}
for i := 0; i < v.Len(); i++ {
- for _, pair := range innerEncoder("", v.Index(i)) {
+ innerPairs, err := innerEncoder("", v.Index(i))
+ if err != nil {
+ return nil, err
+ }
+ for _, pair := range innerPairs {
elements = append(elements, pair.value)
}
}
if len(elements) == 0 {
- return []Pair{}
+ return []Pair{}, nil
}
- return []Pair{{key, strings.Join(elements, ",")}}
+ return []Pair{{key, strings.Join(elements, ",")}}, nil
}
case ArrayQueryFormatRepeat:
innerEncoder := e.typeEncoder(t.Elem())
- return func(key string, value reflect.Value) (pairs []Pair) {
+ return func(key string, value reflect.Value) (pairs []Pair, err error) {
for i := 0; i < value.Len(); i++ {
- pairs = append(pairs, innerEncoder(key, value.Index(i))...)
+ subpairs, suberr := innerEncoder(key, value.Index(i))
+ if suberr != nil {
+ err = suberr
+ }
+ pairs = append(pairs, subpairs...)
}
- return pairs
+ return
}
case ArrayQueryFormatIndices:
panic("The array indices format is not supported yet")
case ArrayQueryFormatBrackets:
innerEncoder := e.typeEncoder(t.Elem())
- return func(key string, value reflect.Value) []Pair {
- pairs := []Pair{}
+ return func(key string, value reflect.Value) (pairs []Pair, err error) {
+ pairs = []Pair{}
for i := 0; i < value.Len(); i++ {
- pairs = append(pairs, innerEncoder(key+"[]", value.Index(i))...)
+ subpairs, suberr := innerEncoder(key+"[]", value.Index(i))
+ if suberr != nil {
+ err = suberr
+ }
+ pairs = append(pairs, subpairs...)
}
- return pairs
+ return
}
default:
panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat))
@@ -271,46 +328,46 @@ func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc {
inner := t.Elem()
innerEncoder := e.newPrimitiveTypeEncoder(inner)
- return func(key string, v reflect.Value) []Pair {
+ return func(key string, v reflect.Value) ([]Pair, error) {
if !v.IsValid() || v.IsNil() {
- return nil
+ return nil, nil
}
return innerEncoder(key, v.Elem())
}
case reflect.String:
- return func(key string, v reflect.Value) []Pair {
- return []Pair{{key, v.String()}}
+ return func(key string, v reflect.Value) ([]Pair, error) {
+ return []Pair{{key, v.String()}}, nil
}
case reflect.Bool:
- return func(key string, v reflect.Value) []Pair {
+ return func(key string, v reflect.Value) ([]Pair, error) {
if v.Bool() {
- return []Pair{{key, "true"}}
+ return []Pair{{key, "true"}}, nil
}
- return []Pair{{key, "false"}}
+ return []Pair{{key, "false"}}, nil
}
case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64:
- return func(key string, v reflect.Value) []Pair {
- return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}
+ return func(key string, v reflect.Value) ([]Pair, error) {
+ return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil
}
case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64:
- return func(key string, v reflect.Value) []Pair {
- return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}
+ return func(key string, v reflect.Value) ([]Pair, error) {
+ return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil
}
case reflect.Float32, reflect.Float64:
- return func(key string, v reflect.Value) []Pair {
- return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}
+ return func(key string, v reflect.Value) ([]Pair, error) {
+ return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil
}
case reflect.Complex64, reflect.Complex128:
bitSize := 64
if t.Kind() == reflect.Complex128 {
bitSize = 128
}
- return func(key string, v reflect.Value) []Pair {
- return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}
+ return func(key string, v reflect.Value) ([]Pair, error) {
+ return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil
}
default:
- return func(key string, v reflect.Value) []Pair {
- return nil
+ return func(key string, v reflect.Value) ([]Pair, error) {
+ return nil, nil
}
}
}
@@ -319,15 +376,14 @@ func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
f, _ := t.FieldByName("Value")
enc := e.typeEncoder(f.Type)
- return func(key string, value reflect.Value) []Pair {
+ return func(key string, value reflect.Value) ([]Pair, error) {
present := value.FieldByName("Present")
if !present.Bool() {
- return nil
+ return nil, nil
}
null := value.FieldByName("Null")
if null.Bool() {
- // TODO: Error?
- return nil
+ return nil, fmt.Errorf("apiquery: field cannot be null")
}
raw := value.FieldByName("Raw")
if !raw.IsNil() {
@@ -337,21 +393,21 @@ func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
}
}
-func (e *encoder) newTimeTypeEncoder(t reflect.Type) encoderFunc {
+func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc {
format := e.dateFormat
- return func(key string, value reflect.Value) []Pair {
+ return func(key string, value reflect.Value) ([]Pair, error) {
return []Pair{{
key,
value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format),
- }}
+ }}, nil
}
}
func (e encoder) newInterfaceEncoder() encoderFunc {
- return func(key string, value reflect.Value) []Pair {
+ return func(key string, value reflect.Value) ([]Pair, error) {
value = value.Elem()
if !value.IsValid() {
- return nil
+ return nil, nil
}
return e.typeEncoder(value.Type())(key, value)
}
@@ -6,26 +6,31 @@ import (
"time"
)
-func MarshalWithSettings(value interface{}, settings QuerySettings) url.Values {
+func MarshalWithSettings(value any, settings QuerySettings) (url.Values, error) {
e := encoder{time.RFC3339, true, settings}
kv := url.Values{}
val := reflect.ValueOf(value)
if !val.IsValid() {
- return nil
+ return nil, nil
}
typ := val.Type()
- for _, pair := range e.typeEncoder(typ)("", val) {
+
+ pairs, err := e.typeEncoder(typ)("", val)
+ if err != nil {
+ return nil, err
+ }
+ for _, pair := range pairs {
kv.Add(pair.key, pair.value)
}
- return kv
+ return kv, nil
}
-func Marshal(value interface{}) url.Values {
+func Marshal(value any) (url.Values, error) {
return MarshalWithSettings(value, QuerySettings{})
}
type Queryer interface {
- URLQuery() url.Values
+ URLQuery() (url.Values, error)
}
type QuerySettings struct {
@@ -6,15 +6,15 @@ import (
"github.com/openai/openai-go/packages/param"
)
-func (e *encoder) newRichFieldTypeEncoder(t reflect.Type, underlyingValueIdx []int) encoderFunc {
- underlying := t.FieldByIndex(underlyingValueIdx)
- primitiveEncoder := e.newPrimitiveTypeEncoder(underlying.Type)
- return func(key string, value reflect.Value) []Pair {
- if fielder, ok := value.Interface().(param.Optional); ok && fielder.IsPresent() {
- return primitiveEncoder(key, value.FieldByIndex(underlyingValueIdx))
- } else if ok && fielder.IsNull() {
- return []Pair{{key, "null"}}
+func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc {
+ f, _ := t.FieldByName("Value")
+ enc := e.typeEncoder(f.Type)
+ return func(key string, value reflect.Value) ([]Pair, error) {
+ if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() {
+ return enc(key, value.FieldByIndex(f.Index))
+ } else if ok && param.IsNull(opt) {
+ return []Pair{{key, "null"}}, nil
}
- return nil
+ return nil, nil
}
}
@@ -776,7 +776,7 @@ type mapEncoder struct {
}
func (me mapEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
- if v.IsNil() {
+ if v.IsNil() /* EDIT(begin) */ || sentinel.IsValueNull(v) /* EDIT(end) */ {
e.WriteString("null")
return
}
@@ -855,7 +855,7 @@ type sliceEncoder struct {
}
func (se sliceEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
- if v.IsNil() {
+ if v.IsNil() /* EDIT(begin) */ || sentinel.IsValueNull(v) /* EDIT(end) */ {
e.WriteString("null")
return
}
@@ -916,14 +916,7 @@ type ptrEncoder struct {
}
func (pe ptrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) {
- // EDIT(begin)
- //
- // if v.IsNil() {
- // e.WriteString("null")
- // return
- // }
-
- if v.IsNil() || sentinel.IsValueNullPtr(v) || sentinel.IsValueNullSlice(v) {
+ if v.IsNil() {
e.WriteString("null")
return
}
@@ -6,52 +6,41 @@ import (
"sync"
)
-var nullPtrsCache sync.Map // map[reflect.Type]*T
-
-func NullPtr[T any]() *T {
- t := shims.TypeFor[T]()
- ptr, loaded := nullPtrsCache.Load(t) // avoid premature allocation
- if !loaded {
- ptr, _ = nullPtrsCache.LoadOrStore(t, new(T))
- }
- return (ptr.(*T))
+type cacheEntry struct {
+ x any
+ ptr uintptr
+ kind reflect.Kind
}
-var nullSlicesCache sync.Map // map[reflect.Type][]T
+var nullCache sync.Map // map[reflect.Type]cacheEntry
-func NullSlice[T any]() []T {
+func NewNullSentinel[T any](mk func() T) T {
t := shims.TypeFor[T]()
- slice, loaded := nullSlicesCache.Load(t) // avoid premature allocation
+ entry, loaded := nullCache.Load(t) // avoid premature allocation
if !loaded {
- slice, _ = nullSlicesCache.LoadOrStore(t, []T{})
+ x := mk()
+ ptr := reflect.ValueOf(x).Pointer()
+ entry, _ = nullCache.LoadOrStore(t, cacheEntry{x, ptr, t.Kind()})
}
- return slice.([]T)
+ return entry.(cacheEntry).x.(T)
}
-func IsNullPtr[T any](ptr *T) bool {
- nullptr, ok := nullPtrsCache.Load(shims.TypeFor[T]())
- return ok && ptr == nullptr.(*T)
-}
-
-func IsNullSlice[T any](slice []T) bool {
- nullSlice, ok := nullSlicesCache.Load(shims.TypeFor[T]())
- return ok && reflect.ValueOf(slice).Pointer() == reflect.ValueOf(nullSlice).Pointer()
-}
-
-// internal only
-func IsValueNullPtr(v reflect.Value) bool {
- if v.Kind() != reflect.Ptr {
- return false
+// for internal use only
+func IsValueNull(v reflect.Value) bool {
+ switch v.Kind() {
+ case reflect.Map, reflect.Slice:
+ null, ok := nullCache.Load(v.Type())
+ return ok && v.Pointer() == null.(cacheEntry).ptr
}
- nullptr, ok := nullPtrsCache.Load(v.Type().Elem())
- return ok && v.Pointer() == reflect.ValueOf(nullptr).Pointer()
+ return false
}
-// internal only
-func IsValueNullSlice(v reflect.Value) bool {
- if v.Kind() != reflect.Slice {
- return false
+func IsNull[T any](v T) bool {
+ t := shims.TypeFor[T]()
+ switch t.Kind() {
+ case reflect.Map, reflect.Slice:
+ null, ok := nullCache.Load(t)
+ return ok && reflect.ValueOf(v).Pointer() == null.(cacheEntry).ptr
}
- nullSlice, ok := nullSlicesCache.Load(v.Type().Elem())
- return ok && v.Pointer() == reflect.ValueOf(nullSlice).Pointer()
+ return false
}
@@ -11,21 +11,23 @@ type TimeMarshaler interface {
MarshalJSONWithTimeLayout(string) []byte
}
-var timeType = shims.TypeFor[time.Time]()
+func TimeLayout(fmt string) string {
+ switch fmt {
+ case "", "date-time":
+ return time.RFC3339
+ case "date":
+ return time.DateOnly
+ default:
+ return fmt
+ }
+}
-const DateFmt = "2006-01-02"
+var timeType = shims.TypeFor[time.Time]()
func newTimeEncoder() encoderFunc {
return func(e *encodeState, v reflect.Value, opts encOpts) {
t := v.Interface().(time.Time)
- fmtted := t.Format(opts.timefmt)
- if opts.timefmt == "date" {
- fmtted = t.Format(DateFmt)
- }
- // Default to RFC3339 if format is invalid
- if fmtted == "" {
- fmtted = t.Format(time.RFC3339)
- }
+ fmtted := t.Format(TimeLayout(opts.timefmt))
stringEncoder(e, reflect.ValueOf(fmtted), opts)
}
}
@@ -1,27 +0,0 @@
-package param
-
-import "fmt"
-
-type FieldLike interface{ field() }
-
-// Field is a wrapper used for all values sent to the API,
-// to distinguish zero values from null or omitted fields.
-//
-// It also allows sending arbitrary deserializable values.
-//
-// To instantiate a Field, use the helpers exported from
-// the package root: `F()`, `Null()`, `Raw()`, etc.
-type Field[T any] struct {
- FieldLike
- Value T
- Null bool
- Present bool
- Raw any
-}
-
-func (f Field[T]) String() string {
- if s, ok := any(f.Value).(fmt.Stringer); ok {
- return s.String()
- }
- return fmt.Sprintf("%v", f.Value)
-}
@@ -0,0 +1,30 @@
+package paramutil
+
+import (
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/packages/respjson"
+)
+
+func AddrIfPresent[T comparable](v param.Opt[T]) *T {
+ if v.Valid() {
+ return &v.Value
+ }
+ return nil
+}
+
+func ToOpt[T comparable](v T, meta respjson.Field) param.Opt[T] {
+ if meta.Valid() {
+ return param.NewOpt(v)
+ } else if meta.Raw() == respjson.Null {
+ return param.Null[T]()
+ }
+ return param.Opt[T]{}
+}
+
+// Checks if the value is not omitted and not null
+func Valid(v param.ParamStruct) bool {
+ if ovr, ok := v.Overrides(); ok {
+ return ovr != nil
+ }
+ return !param.IsNull(v) && !param.IsOmitted(v)
+}
@@ -1,11 +1,12 @@
-package param
+package paramutil
import (
"fmt"
+ "github.com/openai/openai-go/packages/param"
"reflect"
)
-var paramUnionType = reflect.TypeOf(APIUnion{})
+var paramUnionType = reflect.TypeOf(param.APIUnion{})
// VariantFromUnion can be used to extract the present variant from a param union type.
// A param union type is a struct with an embedded field of [APIUnion].
@@ -88,7 +88,7 @@ type PreRequestOptionFunc func(*RequestConfig) error
func (s RequestOptionFunc) Apply(r *RequestConfig) error { return s(r) }
func (s PreRequestOptionFunc) Apply(r *RequestConfig) error { return s(r) }
-func NewRequestConfig(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) (*RequestConfig, error) {
+func NewRequestConfig(ctx context.Context, method string, u string, body any, dst any, opts ...RequestOption) (*RequestConfig, error) {
var reader io.Reader
contentType := "application/json"
@@ -116,7 +116,11 @@ func NewRequestConfig(ctx context.Context, method string, u string, body interfa
}
if body, ok := body.(apiquery.Queryer); ok {
hasSerializationFunc = true
- params := body.URLQuery().Encode()
+ q, err := body.URLQuery()
+ if err != nil {
+ return nil, err
+ }
+ params := q.Encode()
if params != "" {
u = u + "?" + params
}
@@ -185,6 +189,12 @@ func NewRequestConfig(ctx context.Context, method string, u string, body interfa
return &cfg, nil
}
+// This interface is primarily used to describe an [*http.Client], but also
+// supports custom HTTP implementations.
+type HTTPDoer interface {
+ Do(req *http.Request) (*http.Response, error)
+}
+
// RequestConfig represents all the state related to one request.
//
// Editing the variables inside RequestConfig directly is unstable api. Prefer
@@ -195,15 +205,20 @@ type RequestConfig struct {
Context context.Context
Request *http.Request
BaseURL *url.URL
+ // DefaultBaseURL will be used if BaseURL is not explicitly overridden using
+ // WithBaseURL.
+ DefaultBaseURL *url.URL
+ CustomHTTPDoer HTTPDoer
HTTPClient *http.Client
Middlewares []middleware
APIKey string
Organization string
Project string
+ WebhookSecret string
// If ResponseBodyInto not nil, then we will attempt to deserialize into
// ResponseBodyInto. If Destination is a []byte, then it will return the body as
// is.
- ResponseBodyInto interface{}
+ ResponseBodyInto any
// ResponseInto copies the \*http.Response of the corresponding request into the
// given address
ResponseInto **http.Response
@@ -236,7 +251,7 @@ func shouldRetry(req *http.Request, res *http.Response) bool {
return true
}
- // If the header explictly wants a retry behavior, respect that over the
+ // If the header explicitly wants a retry behavior, respect that over the
// http status code.
if res.Header.Get("x-should-retry") == "true" {
return true
@@ -362,7 +377,11 @@ func retryDelay(res *http.Response, retryCount int) time.Duration {
func (cfg *RequestConfig) Execute() (err error) {
if cfg.BaseURL == nil {
- return fmt.Errorf("requestconfig: base url is not set")
+ if cfg.DefaultBaseURL != nil {
+ cfg.BaseURL = cfg.DefaultBaseURL
+ } else {
+ return fmt.Errorf("requestconfig: base url is not set")
+ }
}
cfg.Request.URL, err = cfg.BaseURL.Parse(strings.TrimLeft(cfg.Request.URL.String(), "/"))
@@ -394,6 +413,9 @@ func (cfg *RequestConfig) Execute() (err error) {
}
handler := cfg.HTTPClient.Do
+ if cfg.CustomHTTPDoer != nil {
+ handler = cfg.CustomHTTPDoer.Do
+ }
for i := len(cfg.Middlewares) - 1; i >= 0; i -= 1 {
handler = applyMiddleware(cfg.Middlewares[i], handler)
}
@@ -494,6 +516,7 @@ func (cfg *RequestConfig) Execute() (err error) {
}
contents, err := io.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
return fmt.Errorf("error reading response body: %w", err)
}
@@ -517,21 +540,21 @@ func (cfg *RequestConfig) Execute() (err error) {
return nil
}
- // If the response happens to be a byte array, deserialize the body as-is.
switch dst := cfg.ResponseBodyInto.(type) {
+ // If the response happens to be a byte array, deserialize the body as-is.
case *[]byte:
*dst = contents
- }
-
- err = json.NewDecoder(bytes.NewReader(contents)).Decode(cfg.ResponseBodyInto)
- if err != nil {
- return fmt.Errorf("error parsing response json: %w", err)
+ default:
+ err = json.NewDecoder(bytes.NewReader(contents)).Decode(cfg.ResponseBodyInto)
+ if err != nil {
+ return fmt.Errorf("error parsing response json: %w", err)
+ }
}
return nil
}
-func ExecuteNewRequest(ctx context.Context, method string, u string, body interface{}, dst interface{}, opts ...RequestOption) error {
+func ExecuteNewRequest(ctx context.Context, method string, u string, body any, dst any, opts ...RequestOption) error {
cfg, err := NewRequestConfig(ctx, method, u, body, dst, opts...)
if err != nil {
return err
@@ -562,6 +585,7 @@ func (cfg *RequestConfig) Clone(ctx context.Context) *RequestConfig {
APIKey: cfg.APIKey,
Organization: cfg.Organization,
Project: cfg.Project,
+ WebhookSecret: cfg.WebhookSecret,
}
return new
@@ -577,17 +601,35 @@ func (cfg *RequestConfig) Apply(opts ...RequestOption) error {
return nil
}
+// PreRequestOptions is used to collect all the options which need to be known before
+// a call to [RequestConfig.ExecuteNewRequest], such as path parameters
+// or global defaults.
+// PreRequestOptions will return a [RequestConfig] with the options applied.
+//
+// Only request option functions of type [PreRequestOptionFunc] are applied.
func PreRequestOptions(opts ...RequestOption) (RequestConfig, error) {
cfg := RequestConfig{}
for _, opt := range opts {
- if _, ok := opt.(PreRequestOptionFunc); !ok {
- continue
+ if opt, ok := opt.(PreRequestOptionFunc); ok {
+ err := opt.Apply(&cfg)
+ if err != nil {
+ return cfg, err
+ }
}
+ }
+ return cfg, nil
+}
- err := opt.Apply(&cfg)
+// WithDefaultBaseURL returns a RequestOption that sets the client's default Base URL.
+// This is always overridden by setting a base URL with WithBaseURL.
+// WithBaseURL should be used instead of WithDefaultBaseURL except in internal code.
+func WithDefaultBaseURL(baseURL string) RequestOption {
+ u, err := url.Parse(baseURL)
+ return RequestOptionFunc(func(r *RequestConfig) error {
if err != nil {
- return cfg, err
+ return err
}
- }
- return cfg, nil
+ r.DefaultBaseURL = u
+ return nil
+ })
}
@@ -2,4 +2,4 @@
package internal
-const PackageVersion = "0.1.0-beta.2" // x-release-please-version
+const PackageVersion = "1.8.2" // x-release-please-version
@@ -12,7 +12,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -96,14 +96,13 @@ type Model struct {
Object constant.Model `json:"object,required"`
// The organization that owns the model.
OwnedBy string `json:"owned_by,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Created resp.Field
- Object resp.Field
- OwnedBy resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Created respjson.Field
+ Object respjson.Field
+ OwnedBy respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -118,13 +117,12 @@ type ModelDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object string `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -5,15 +5,13 @@ package openai
import (
"context"
"net/http"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// ModerationService contains methods and other services that help with interacting
@@ -53,14 +51,13 @@ type Moderation struct {
CategoryScores ModerationCategoryScores `json:"category_scores,required"`
// Whether any of the below categories are flagged.
Flagged bool `json:"flagged,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Categories resp.Field
- CategoryAppliedInputTypes resp.Field
- CategoryScores resp.Field
- Flagged resp.Field
- ExtraFields map[string]resp.Field
+ Categories respjson.Field
+ CategoryAppliedInputTypes respjson.Field
+ CategoryScores respjson.Field
+ Flagged respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -116,23 +113,22 @@ type ModerationCategories struct {
Violence bool `json:"violence,required"`
// Content that depicts death, violence, or physical injury in graphic detail.
ViolenceGraphic bool `json:"violence/graphic,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Harassment resp.Field
- HarassmentThreatening resp.Field
- Hate resp.Field
- HateThreatening resp.Field
- Illicit resp.Field
- IllicitViolent resp.Field
- SelfHarm resp.Field
- SelfHarmInstructions resp.Field
- SelfHarmIntent resp.Field
- Sexual resp.Field
- SexualMinors resp.Field
- Violence resp.Field
- ViolenceGraphic resp.Field
- ExtraFields map[string]resp.Field
+ Harassment respjson.Field
+ HarassmentThreatening respjson.Field
+ Hate respjson.Field
+ HateThreatening respjson.Field
+ Illicit respjson.Field
+ IllicitViolent respjson.Field
+ SelfHarm respjson.Field
+ SelfHarmInstructions respjson.Field
+ SelfHarmIntent respjson.Field
+ Sexual respjson.Field
+ SexualMinors respjson.Field
+ Violence respjson.Field
+ ViolenceGraphic respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -146,48 +142,73 @@ func (r *ModerationCategories) UnmarshalJSON(data []byte) error {
// A list of the categories along with the input type(s) that the score applies to.
type ModerationCategoryAppliedInputTypes struct {
// The applied input type(s) for the category 'harassment'.
+ //
+ // Any of "text".
Harassment []string `json:"harassment,required"`
// The applied input type(s) for the category 'harassment/threatening'.
+ //
+ // Any of "text".
HarassmentThreatening []string `json:"harassment/threatening,required"`
// The applied input type(s) for the category 'hate'.
+ //
+ // Any of "text".
Hate []string `json:"hate,required"`
// The applied input type(s) for the category 'hate/threatening'.
+ //
+ // Any of "text".
HateThreatening []string `json:"hate/threatening,required"`
// The applied input type(s) for the category 'illicit'.
+ //
+ // Any of "text".
Illicit []string `json:"illicit,required"`
// The applied input type(s) for the category 'illicit/violent'.
+ //
+ // Any of "text".
IllicitViolent []string `json:"illicit/violent,required"`
// The applied input type(s) for the category 'self-harm'.
+ //
+ // Any of "text", "image".
SelfHarm []string `json:"self-harm,required"`
// The applied input type(s) for the category 'self-harm/instructions'.
+ //
+ // Any of "text", "image".
SelfHarmInstructions []string `json:"self-harm/instructions,required"`
// The applied input type(s) for the category 'self-harm/intent'.
+ //
+ // Any of "text", "image".
SelfHarmIntent []string `json:"self-harm/intent,required"`
// The applied input type(s) for the category 'sexual'.
+ //
+ // Any of "text", "image".
Sexual []string `json:"sexual,required"`
// The applied input type(s) for the category 'sexual/minors'.
+ //
+ // Any of "text".
SexualMinors []string `json:"sexual/minors,required"`
// The applied input type(s) for the category 'violence'.
+ //
+ // Any of "text", "image".
Violence []string `json:"violence,required"`
// The applied input type(s) for the category 'violence/graphic'.
+ //
+ // Any of "text", "image".
ViolenceGraphic []string `json:"violence/graphic,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Harassment resp.Field
- HarassmentThreatening resp.Field
- Hate resp.Field
- HateThreatening resp.Field
- Illicit resp.Field
- IllicitViolent resp.Field
- SelfHarm resp.Field
- SelfHarmInstructions resp.Field
- SelfHarmIntent resp.Field
- Sexual resp.Field
- SexualMinors resp.Field
- Violence resp.Field
- ViolenceGraphic resp.Field
- ExtraFields map[string]resp.Field
+ Harassment respjson.Field
+ HarassmentThreatening respjson.Field
+ Hate respjson.Field
+ HateThreatening respjson.Field
+ Illicit respjson.Field
+ IllicitViolent respjson.Field
+ SelfHarm respjson.Field
+ SelfHarmInstructions respjson.Field
+ SelfHarmIntent respjson.Field
+ Sexual respjson.Field
+ SexualMinors respjson.Field
+ Violence respjson.Field
+ ViolenceGraphic respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -226,23 +247,22 @@ type ModerationCategoryScores struct {
Violence float64 `json:"violence,required"`
// The score for the category 'violence/graphic'.
ViolenceGraphic float64 `json:"violence/graphic,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Harassment resp.Field
- HarassmentThreatening resp.Field
- Hate resp.Field
- HateThreatening resp.Field
- Illicit resp.Field
- IllicitViolent resp.Field
- SelfHarm resp.Field
- SelfHarmInstructions resp.Field
- SelfHarmIntent resp.Field
- Sexual resp.Field
- SexualMinors resp.Field
- Violence resp.Field
- ViolenceGraphic resp.Field
- ExtraFields map[string]resp.Field
+ Harassment respjson.Field
+ HarassmentThreatening respjson.Field
+ Hate respjson.Field
+ HateThreatening respjson.Field
+ Illicit respjson.Field
+ IllicitViolent respjson.Field
+ SelfHarm respjson.Field
+ SelfHarmInstructions respjson.Field
+ SelfHarmIntent respjson.Field
+ Sexual respjson.Field
+ SexualMinors respjson.Field
+ Violence respjson.Field
+ ViolenceGraphic respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -266,13 +286,13 @@ type ModerationImageURLInputParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ModerationImageURLInputParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ModerationImageURLInputParam) MarshalJSON() (data []byte, err error) {
type shadow ModerationImageURLInputParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ModerationImageURLInputParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Contains either an image URL or a data URL for a base64 encoded image.
//
@@ -283,15 +303,13 @@ type ModerationImageURLInputImageURLParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ModerationImageURLInputImageURLParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ModerationImageURLInputImageURLParam) MarshalJSON() (data []byte, err error) {
type shadow ModerationImageURLInputImageURLParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ModerationImageURLInputImageURLParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type ModerationModel = string
@@ -323,13 +341,11 @@ type ModerationMultiModalInputUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ModerationMultiModalInputUnionParam) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u ModerationMultiModalInputUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ModerationMultiModalInputUnionParam](u.OfImageURL, u.OfText)
+ return param.MarshalUnion(u, u.OfImageURL, u.OfText)
+}
+func (u *ModerationMultiModalInputUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ModerationMultiModalInputUnionParam) asAny() any {
@@ -370,16 +386,8 @@ func (u ModerationMultiModalInputUnionParam) GetType() *string {
func init() {
apijson.RegisterUnion[ModerationMultiModalInputUnionParam](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ModerationImageURLInputParam{}),
- DiscriminatorValue: "image_url",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(ModerationTextInputParam{}),
- DiscriminatorValue: "text",
- },
+ apijson.Discriminator[ModerationImageURLInputParam]("image_url"),
+ apijson.Discriminator[ModerationTextInputParam]("text"),
)
}
@@ -396,13 +404,13 @@ type ModerationTextInputParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ModerationTextInputParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ModerationTextInputParam) MarshalJSON() (data []byte, err error) {
type shadow ModerationTextInputParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ModerationTextInputParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Represents if a given text input is potentially harmful.
type ModerationNewResponse struct {
@@ -412,13 +420,12 @@ type ModerationNewResponse struct {
Model string `json:"model,required"`
// A list of moderation objects.
Results []Moderation `json:"results,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Model resp.Field
- Results resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Model respjson.Field
+ Results respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -441,37 +448,36 @@ type ModerationNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ModerationNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r ModerationNewParams) MarshalJSON() (data []byte, err error) {
type shadow ModerationNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ModerationNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type ModerationNewParamsInputUnion struct {
OfString param.Opt[string] `json:",omitzero,inline"`
- OfModerationNewsInputArray []string `json:",omitzero,inline"`
+ OfStringArray []string `json:",omitzero,inline"`
OfModerationMultiModalArray []ModerationMultiModalInputUnionParam `json:",omitzero,inline"`
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ModerationNewParamsInputUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u ModerationNewParamsInputUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ModerationNewParamsInputUnion](u.OfString, u.OfModerationNewsInputArray, u.OfModerationMultiModalArray)
+ return param.MarshalUnion(u, u.OfString, u.OfStringArray, u.OfModerationMultiModalArray)
+}
+func (u *ModerationNewParamsInputUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ModerationNewParamsInputUnion) asAny() any {
if !param.IsOmitted(u.OfString) {
return &u.OfString.Value
- } else if !param.IsOmitted(u.OfModerationNewsInputArray) {
- return &u.OfModerationNewsInputArray
+ } else if !param.IsOmitted(u.OfStringArray) {
+ return &u.OfStringArray
} else if !param.IsOmitted(u.OfModerationMultiModalArray) {
return &u.OfModerationMultiModalArray
}
@@ -0,0 +1,38 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package option
+
+import (
+ "log"
+ "net/http"
+ "net/http/httputil"
+)
+
+// WithDebugLog logs the HTTP request and response content.
+// If the logger parameter is nil, it uses the default logger.
+//
+// WithDebugLog is for debugging and development purposes only.
+// It should not be used in production code. The behavior and interface
+// of WithDebugLog is not guaranteed to be stable.
+func WithDebugLog(logger *log.Logger) RequestOption {
+ return WithMiddleware(func(req *http.Request, nxt MiddlewareNext) (*http.Response, error) {
+ if logger == nil {
+ logger = log.Default()
+ }
+
+ if reqBytes, err := httputil.DumpRequest(req, true); err == nil {
+ logger.Printf("Request Content:\n%s\n", reqBytes)
+ }
+
+ resp, err := nxt(req)
+ if err != nil {
+ return resp, err
+ }
+
+ if respBytes, err := httputil.DumpResponse(resp, true); err == nil {
+ logger.Printf("Response Content:\n%s\n", respBytes)
+ }
+
+ return resp, err
+ })
+}
@@ -6,7 +6,6 @@ import (
"bytes"
"fmt"
"io"
- "log"
"net/http"
"net/url"
"strings"
@@ -24,12 +23,15 @@ import (
type RequestOption = requestconfig.RequestOption
// WithBaseURL returns a RequestOption that sets the BaseURL for the client.
+//
+// For security reasons, ensure that the base URL is trusted.
func WithBaseURL(base string) RequestOption {
u, err := url.Parse(base)
- if err != nil {
- log.Fatalf("failed to parse BaseURL: %s\n", err)
- }
return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error {
+ if err != nil {
+ return fmt.Errorf("requestoption: WithBaseURL failed to parse url %s\n", err)
+ }
+
if u.Path != "" && !strings.HasSuffix(u.Path, "/") {
u.Path += "/"
}
@@ -38,11 +40,34 @@ func WithBaseURL(base string) RequestOption {
})
}
-// WithHTTPClient returns a RequestOption that changes the underlying [http.Client] used to make this
+// HTTPClient is primarily used to describe an [*http.Client], but also
+// supports custom implementations.
+//
+// For bespoke implementations, prefer using an [*http.Client] with a
+// custom transport. See [http.RoundTripper] for further information.
+type HTTPClient interface {
+ Do(*http.Request) (*http.Response, error)
+}
+
+// WithHTTPClient returns a RequestOption that changes the underlying http client used to make this
// request, which by default is [http.DefaultClient].
-func WithHTTPClient(client *http.Client) RequestOption {
+//
+// For custom uses cases, it is recommended to provide an [*http.Client] with a custom
+// [http.RoundTripper] as its transport, rather than directly implementing [HTTPClient].
+func WithHTTPClient(client HTTPClient) RequestOption {
return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) error {
- r.HTTPClient = client
+ if client == nil {
+ return fmt.Errorf("requestoption: custom http client cannot be nil")
+ }
+
+ if c, ok := client.(*http.Client); ok {
+ // Prefer the native client if possible.
+ r.HTTPClient = c
+ r.CustomHTTPDoer = nil
+ } else {
+ r.CustomHTTPDoer = client
+ }
+
return nil
})
}
@@ -142,19 +167,27 @@ func WithQueryDel(key string) RequestOption {
// The key accepts a string as defined by the [sjson format].
//
// [sjson format]: https://github.com/tidwall/sjson
-func WithJSONSet(key string, value interface{}) RequestOption {
+func WithJSONSet(key string, value any) RequestOption {
return requestconfig.RequestOptionFunc(func(r *requestconfig.RequestConfig) (err error) {
- if buffer, ok := r.Body.(*bytes.Buffer); ok {
- b := buffer.Bytes()
+ var b []byte
+
+ if r.Body == nil {
+ b, err = sjson.SetBytes(nil, key, value)
+ if err != nil {
+ return err
+ }
+ } else if buffer, ok := r.Body.(*bytes.Buffer); ok {
+ b = buffer.Bytes()
b, err = sjson.SetBytes(b, key, value)
if err != nil {
return err
}
- r.Body = bytes.NewBuffer(b)
- return nil
+ } else {
+ return fmt.Errorf("cannot use WithJSONSet on a body that is not serialized as *bytes.Buffer")
}
- return fmt.Errorf("cannot use WithJSONSet on a body that is not serialized as *bytes.Buffer")
+ r.Body = bytes.NewBuffer(b)
+ return nil
})
}
@@ -229,7 +262,7 @@ func WithRequestTimeout(dur time.Duration) RequestOption {
// environment to be the "production" environment. An environment specifies which base URL
// to use by default.
func WithEnvironmentProduction() RequestOption {
- return WithBaseURL("https://api.openai.com/v1/")
+ return requestconfig.WithDefaultBaseURL("https://api.openai.com/v1/")
}
// WithAPIKey returns a RequestOption that sets the client setting "api_key".
@@ -255,3 +288,11 @@ func WithProject(value string) RequestOption {
return r.Apply(WithHeader("OpenAI-Project", value))
})
}
+
+// WithWebhookSecret returns a RequestOption that sets the client setting "webhook_secret".
+func WithWebhookSecret(value string) requestconfig.PreRequestOptionFunc {
+ return requestconfig.PreRequestOptionFunc(func(r *requestconfig.RequestConfig) error {
+ r.WebhookSecret = value
+ return nil
+ })
+}
@@ -10,7 +10,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
)
// aliased to make [param.APIUnion] private when embedding
@@ -22,12 +22,11 @@ type paramObj = param.APIObject
type Page[T any] struct {
Data []T `json:"data"`
Object string `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
cfg *requestconfig.RequestConfig
@@ -44,6 +43,9 @@ func (r *Page[T]) UnmarshalJSON(data []byte) error {
// there is no next page, this function will return a 'nil' for the page value, but
// will not return an error
func (r *Page[T]) GetNextPage() (res *Page[T], err error) {
+ if len(r.Data) == 0 {
+ return nil, nil
+ }
// This page represents a response that isn't actually paginated at the API level
// so there will never be a next page.
cfg := (*requestconfig.RequestConfig)(nil)
@@ -117,12 +119,11 @@ func (r *PageAutoPager[T]) Index() int {
type CursorPage[T any] struct {
Data []T `json:"data"`
HasMore bool `json:"has_more"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- HasMore resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ HasMore respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
cfg *requestconfig.RequestConfig
@@ -139,7 +140,11 @@ func (r *CursorPage[T]) UnmarshalJSON(data []byte) error {
// there is no next page, this function will return a 'nil' for the page value, but
// will not return an error
func (r *CursorPage[T]) GetNextPage() (res *CursorPage[T], err error) {
- if r.JSON.HasMore.IsPresent() && r.HasMore == false {
+ if len(r.Data) == 0 {
+ return nil, nil
+ }
+
+ if r.JSON.HasMore.Valid() && r.HasMore == false {
return nil, nil
}
items := r.Data
@@ -149,7 +154,10 @@ func (r *CursorPage[T]) GetNextPage() (res *CursorPage[T], err error) {
cfg := r.cfg.Clone(r.cfg.Context)
value := reflect.ValueOf(items[len(items)-1])
field := value.FieldByName("ID")
- cfg.Apply(option.WithQuery("after", field.Interface().(string)))
+ err = cfg.Apply(option.WithQuery("after", field.Interface().(string)))
+ if err != nil {
+ return nil, err
+ }
var raw *http.Response
cfg.ResponseInto = &raw
cfg.ResponseBodyInto = &res
@@ -4,8 +4,6 @@ import (
"encoding/json"
"fmt"
"reflect"
- "strings"
- "sync"
"time"
shimjson "github.com/openai/openai-go/internal/encoding/json"
@@ -13,9 +11,11 @@ import (
"github.com/tidwall/sjson"
)
-// This type will not be stable and shouldn't be relied upon
+// EncodedAsDate is not be stable and shouldn't be relied upon
type EncodedAsDate Opt[time.Time]
+type forceOmit int
+
func (m EncodedAsDate) MarshalJSON() ([]byte, error) {
underlying := Opt[time.Time](m)
bytes := underlying.MarshalJSONWithTimeLayout("2006-01-02")
@@ -25,31 +25,50 @@ func (m EncodedAsDate) MarshalJSON() ([]byte, error) {
return underlying.MarshalJSON()
}
-// This uses a shimmed 'encoding/json' from Go 1.24, to support the 'omitzero' tag
-func MarshalObject[T OverridableObject](f T, underlying any) ([]byte, error) {
- if f.IsNull() {
+// MarshalObject uses a shimmed 'encoding/json' from Go 1.24, to support the 'omitzero' tag
+//
+// Stability for the API of MarshalObject is not guaranteed.
+func MarshalObject[T ParamStruct](f T, underlying any) ([]byte, error) {
+ return MarshalWithExtras(f, underlying, f.extraFields())
+}
+
+// MarshalWithExtras is used to marshal a struct with additional properties.
+//
+// Stability for the API of MarshalWithExtras is not guaranteed.
+func MarshalWithExtras[T ParamStruct, R any](f T, underlying any, extras map[string]R) ([]byte, error) {
+ if f.null() {
return []byte("null"), nil
- } else if extras := f.GetExtraFields(); extras != nil {
+ } else if len(extras) > 0 {
bytes, err := shimjson.Marshal(underlying)
if err != nil {
return nil, err
}
for k, v := range extras {
+ var a any = v
+ if a == Omit {
+ // Errors when handling ForceOmitted are ignored.
+ if b, e := sjson.DeleteBytes(bytes, k); e == nil {
+ bytes = b
+ }
+ continue
+ }
bytes, err = sjson.SetBytes(bytes, k, v)
if err != nil {
return nil, err
}
}
return bytes, nil
- } else if ovr, ok := f.IsOverridden(); ok {
+ } else if ovr, ok := f.Overrides(); ok {
return shimjson.Marshal(ovr)
} else {
return shimjson.Marshal(underlying)
}
}
-// This uses a shimmed 'encoding/json' from Go 1.24, to support the 'omitzero' tag
-func MarshalUnion[T any](variants ...any) ([]byte, error) {
+// MarshalUnion uses a shimmed 'encoding/json' from Go 1.24, to support the 'omitzero' tag
+//
+// Stability for the API of MarshalUnion is not guaranteed.
+func MarshalUnion[T ParamStruct](metadata T, variants ...any) ([]byte, error) {
nPresent := 0
presentIdx := -1
for i, variant := range variants {
@@ -59,6 +78,9 @@ func MarshalUnion[T any](variants ...any) ([]byte, error) {
}
}
if nPresent == 0 || presentIdx == -1 {
+ if ovr, ok := metadata.Overrides(); ok {
+ return shimjson.Marshal(ovr)
+ }
return []byte(`null`), nil
} else if nPresent > 1 {
return nil, &json.MarshalerError{
@@ -69,7 +91,7 @@ func MarshalUnion[T any](variants ...any) ([]byte, error) {
return shimjson.Marshal(variants[presentIdx])
}
-// shimmed from Go 1.23 "reflect" package
+// typeFor is shimmed from Go 1.23 "reflect" package
func typeFor[T any]() reflect.Type {
var v T
if t := reflect.TypeOf(v); t != nil {
@@ -77,49 +99,3 @@ func typeFor[T any]() reflect.Type {
}
return reflect.TypeOf((*T)(nil)).Elem() // only for an interface kind
}
-
-var optStringType = typeFor[Opt[string]]()
-var optIntType = typeFor[Opt[int64]]()
-var optFloatType = typeFor[Opt[float64]]()
-var optBoolType = typeFor[Opt[bool]]()
-
-var OptionalPrimitiveTypes map[reflect.Type][]int
-
-// indexOfUnderlyingValueField must only be called at initialization time
-func indexOfUnderlyingValueField(t reflect.Type) []int {
- field, ok := t.FieldByName("Value")
- if !ok {
- panic("unreachable: initialization issue, underlying value field not found")
- }
- return field.Index
-}
-
-func init() {
- OptionalPrimitiveTypes = map[reflect.Type][]int{
- optStringType: indexOfUnderlyingValueField(optStringType),
- optIntType: indexOfUnderlyingValueField(optIntType),
- optFloatType: indexOfUnderlyingValueField(optFloatType),
- optBoolType: indexOfUnderlyingValueField(optBoolType),
- }
-}
-
-var structFieldsCache sync.Map
-
-func structFields(t reflect.Type) (map[string][]int, error) {
- if cached, ok := structFieldsCache.Load(t); ok {
- return cached.(map[string][]int), nil
- }
- if t.Kind() != reflect.Struct {
- return nil, fmt.Errorf("resp: expected struct but got %v of kind %v", t.String(), t.Kind().String())
- }
- structFields := map[string][]int{}
- for i := 0; i < t.NumField(); i++ {
- field := t.Field(i)
- name := strings.Split(field.Tag.Get("json"), ",")[0]
- if name == "" || name == "-" || field.Anonymous {
- continue
- }
- structFields[name] = field.Index
- }
- return structFields, nil
-}
@@ -0,0 +1,19 @@
+package param
+
+import "github.com/openai/openai-go/internal/encoding/json/sentinel"
+
+// NullMap returns a non-nil map with a length of 0.
+// When used with [MarshalObject] or [MarshalUnion], it will be marshaled as null.
+//
+// It is unspecified behavior to mutate the map returned by [NullMap].
+func NullMap[MapT ~map[string]T, T any]() MapT {
+ return sentinel.NewNullSentinel(func() MapT { return make(MapT, 1) })
+}
+
+// NullSlice returns a non-nil slice with a length of 0.
+// When used with [MarshalObject] or [MarshalUnion], it will be marshaled as null.
+//
+// It is unspecified behavior to mutate the slice returned by [NullSlice].
+func NullSlice[SliceT ~[]T, T any]() SliceT {
+ return sentinel.NewNullSentinel(func() SliceT { return make(SliceT, 0, 1) })
+}
@@ -2,62 +2,64 @@ package param
import (
"encoding/json"
- "reflect"
+ "fmt"
+ shimjson "github.com/openai/openai-go/internal/encoding/json"
"time"
)
func NewOpt[T comparable](v T) Opt[T] {
- return Opt[T]{Value: v, Status: included}
+ return Opt[T]{Value: v, status: included}
}
-// Sets an optional field to null, to set an object to null use [NullObj].
-func NullOpt[T comparable]() Opt[T] { return Opt[T]{Status: null} }
-
-type Opt[T comparable] struct {
- Value T
- // indicates whether the field should be omitted, null, or valid
- Status Status
-}
+// Null creates optional field with the JSON value "null".
+//
+// To set a struct to null, use [NullStruct].
+func Null[T comparable]() Opt[T] { return Opt[T]{status: null} }
-type Status int8
+type status int8
const (
- omitted Status = iota
+ omitted status = iota
null
included
)
-type Optional interface {
- // IsPresent returns true if the value is not "null" or omitted
- IsPresent() bool
-
- // IsOmitted returns true if the value is omitted, it returns false if the value is "null".
- IsOmitted() bool
-
- // IsNull returns true if the value is "null", it returns false if the value is omitted.
- IsNull() bool
+// Opt represents an optional parameter of type T. Use
+// the [Opt.Valid] method to confirm.
+type Opt[T comparable] struct {
+ Value T
+ // indicates whether the field should be omitted, null, or valid
+ status status
+ opt
}
-// IsPresent returns true if the value is not "null" and not omitted
-func (o Opt[T]) IsPresent() bool {
+// Valid returns true if the value is not "null" or omitted.
+//
+// To check if explicitly null, use [Opt.Null].
+func (o Opt[T]) Valid() bool {
var empty Opt[T]
- return o.Status == included || o != empty && o.Status != null
+ return o.status == included || o != empty && o.status != null
}
-// IsNull returns true if the value is specifically the JSON value "null".
-// It returns false if the value is omitted.
-//
-// Prefer to use [IsPresent] to check the presence of a value.
-func (o Opt[T]) IsNull() bool { return o.Status == null }
+func (o Opt[T]) Or(v T) T {
+ if o.Valid() {
+ return o.Value
+ }
+ return v
+}
-// IsOmitted returns true if the value is omitted.
-// It returns false if the value is the JSON value "null".
-//
-// Prefer to use [IsPresent] to check the presence of a value.
-func (o Opt[T]) IsOmitted() bool { return o == Opt[T]{} }
+func (o Opt[T]) String() string {
+ if o.null() {
+ return "null"
+ }
+ if s, ok := any(o.Value).(fmt.Stringer); ok {
+ return s.String()
+ }
+ return fmt.Sprintf("%v", o.Value)
+}
func (o Opt[T]) MarshalJSON() ([]byte, error) {
- if !o.IsPresent() {
+ if !o.Valid() {
return []byte("null"), nil
}
return json.Marshal(o.Value)
@@ -65,39 +67,55 @@ func (o Opt[T]) MarshalJSON() ([]byte, error) {
func (o *Opt[T]) UnmarshalJSON(data []byte) error {
if string(data) == "null" {
- o.Status = null
+ o.status = null
return nil
}
- return json.Unmarshal(data, &o.Value)
-}
-func (o Opt[T]) Or(v T) T {
- if o.IsPresent() {
- return o.Value
+ var value *T
+ if err := json.Unmarshal(data, &value); err != nil {
+ return err
}
- return v
-}
-// This is a sketchy way to implement time Formatting
-var timeType = reflect.TypeOf(time.Time{})
-var timeTimeValueLoc, _ = reflect.TypeOf(Opt[time.Time]{}).FieldByName("Value")
+ if value == nil {
+ o.status = omitted
+ return nil
+ }
+
+ o.status = included
+ o.Value = *value
+ return nil
+}
-// Don't worry about this function, returns nil to fallback towards [MarshalJSON]
+// MarshalJSONWithTimeLayout is necessary to bypass the internal caching performed
+// by [json.Marshal]. Prefer to use [Opt.MarshalJSON] instead.
+//
+// This function requires that the generic type parameter of [Opt] is not [time.Time].
func (o Opt[T]) MarshalJSONWithTimeLayout(format string) []byte {
t, ok := any(o.Value).(time.Time)
- if !ok || o.IsNull() {
+ if !ok || o.null() {
return nil
}
- if format == "" {
- format = time.RFC3339
- } else if format == "date" {
- format = "2006-01-02"
- }
-
- b, err := json.Marshal(t.Format(format))
+ b, err := json.Marshal(t.Format(shimjson.TimeLayout(format)))
if err != nil {
return nil
}
return b
}
+
+func (o Opt[T]) null() bool { return o.status == null }
+func (o Opt[T]) isZero() bool { return o == Opt[T]{} }
+
+// opt helps limit the [Optional] interface to only types in this package
+type opt struct{}
+
+func (opt) implOpt() {}
+
+// This interface is useful for internal purposes.
+type Optional interface {
+ Valid() bool
+ null() bool
+
+ isZero() bool
+ implOpt()
+}
@@ -2,58 +2,99 @@ package param
import (
"encoding/json"
+ "github.com/openai/openai-go/internal/encoding/json/sentinel"
"reflect"
)
-// NullObj is used to mark a struct as null.
-// To send null to an [Opt] field use [NullOpt].
-func NullObj[T NullableObject, PT Settable[T]]() T {
+// NullStruct is used to set a struct to the JSON value null.
+// Check for null structs with [IsNull].
+//
+// Only the first type parameter should be provided,
+// the type PtrT will be inferred.
+//
+// json.Marshal(param.NullStruct[MyStruct]()) -> 'null'
+//
+// To send null to an [Opt] field use [Null].
+func NullStruct[T ParamStruct, PtrT InferPtr[T]]() T {
var t T
- pt := PT(&t)
+ pt := PtrT(&t)
pt.setMetadata(nil)
return *pt
}
-// To override a specific field in a struct, use its [WithExtraFields] method.
-func OverrideObj[T OverridableObject, PT Settable[T]](v any) T {
+// Override replaces the value of a struct with any type.
+//
+// Only the first type parameter should be provided,
+// the type PtrT will be inferred.
+//
+// It's often useful for providing raw JSON
+//
+// param.Override[MyStruct](json.RawMessage(`{"foo": "bar"}`))
+//
+// The public fields of the returned struct T will be unset.
+//
+// To override a specific field in a struct, use its [SetExtraFields] method.
+func Override[T ParamStruct, PtrT InferPtr[T]](v any) T {
var t T
- pt := PT(&t)
- pt.setMetadata(nil)
+ pt := PtrT(&t)
+ pt.setMetadata(v)
return *pt
}
// IsOmitted returns true if v is the zero value of its type.
//
-// It indicates if a field with the `json:"...,omitzero"` tag will be omitted
-// from serialization.
-//
-// If v is set explicitly to the JSON value "null", this function will return false.
-// Therefore, when available, prefer using the [IsPresent] method to check whether
-// a field is present.
+// If IsOmitted is true, and the field uses a `json:"...,omitzero"` tag,
+// the field will be omitted from the request.
//
-// Generally, this function should only be used on structs, arrays, maps.
+// If v is set explicitly to the JSON value "null", IsOmitted returns false.
func IsOmitted(v any) bool {
if v == nil {
return false
}
- if o, ok := v.(interface{ IsOmitted() bool }); ok {
- return o.IsOmitted()
+ if o, ok := v.(Optional); ok {
+ return o.isZero()
}
return reflect.ValueOf(v).IsZero()
}
-type NullableObject = overridableStruct
-type OverridableObject = overridableStruct
+// IsNull returns true if v was set to the JSON value null.
+//
+// To set a param to null use [NullStruct], [Null], [NullMap], or [NullSlice]
+// depending on the type of v.
+//
+// IsNull returns false if the value is omitted.
+func IsNull[T any](v T) bool {
+ if nullable, ok := any(v).(ParamNullable); ok {
+ return nullable.null()
+ }
-type Settable[T overridableStruct] interface {
- setMetadata(any)
- *T
+ switch reflect.TypeOf(v).Kind() {
+ case reflect.Slice, reflect.Map:
+ return sentinel.IsNull(v)
+ }
+
+ return false
+}
+
+// ParamNullable encapsulates all structs in parameters,
+// and all [Opt] types in parameters.
+type ParamNullable interface {
+ null() bool
}
-type overridableStruct interface {
- IsNull() bool
- IsOverridden() (any, bool)
- GetExtraFields() map[string]any
+// ParamStruct represents the set of all structs that are
+// used in API parameters, by convention these usually end in
+// "Params" or "Param".
+type ParamStruct interface {
+ Overrides() (any, bool)
+ null() bool
+ extraFields() map[string]any
+}
+
+// This is an implementation detail and should never be explicitly set.
+type InferPtr[T ParamStruct] interface {
+ setMetadata(any)
+ *T
}
// APIObject should be embedded in api object fields, preferably using an alias to make private
@@ -62,42 +103,62 @@ type APIObject struct{ metadata }
// APIUnion should be embedded in all api unions fields, preferably using an alias to make private
type APIUnion struct{ metadata }
-type metadata struct{ any }
-type metadataNull struct{}
-type metadataExtraFields map[string]any
-
-// IsNull returns true if the field is the explicit value `null`,
-// prefer using [IsPresent] to check for presence, since it checks against null and omitted.
-func (m metadata) IsNull() bool {
- if _, ok := m.any.(metadataNull); ok {
- return true
- }
-
- if msg, ok := m.any.(json.RawMessage); ok {
- return string(msg) == "null"
- }
-
- return false
-}
-
-func (m metadata) IsOverridden() (any, bool) {
+// Overrides returns the value of the struct when it is created with
+// [Override], the second argument helps differentiate an explicit null.
+func (m metadata) Overrides() (any, bool) {
if _, ok := m.any.(metadataExtraFields); ok {
return nil, false
}
return m.any, m.any != nil
}
-func (m metadata) GetExtraFields() map[string]any {
+// ExtraFields returns the extra fields added to the JSON object.
+func (m metadata) ExtraFields() map[string]any {
if extras, ok := m.any.(metadataExtraFields); ok {
return extras
}
return nil
}
-func (m *metadata) WithExtraFields(fields map[string]any) {
- m.any = metadataExtraFields(fields)
+// Omit can be used with [metadata.SetExtraFields] to ensure that a
+// required field is omitted. This is useful as an escape hatch for
+// when a required is unwanted for some unexpected reason.
+const Omit forceOmit = -1
+
+// SetExtraFields adds extra fields to the JSON object.
+//
+// SetExtraFields will override any existing fields with the same key.
+// For security reasons, ensure this is only used with trusted input data.
+//
+// To intentionally omit a required field, use [Omit].
+//
+// foo.SetExtraFields(map[string]any{"bar": Omit})
+//
+// If the struct already contains the field ExtraFields, then this
+// method will have no effect.
+func (m *metadata) SetExtraFields(extraFields map[string]any) {
+ m.any = metadataExtraFields(extraFields)
+}
+
+// extraFields aliases [metadata.ExtraFields] to avoid name collisions.
+func (m metadata) extraFields() map[string]any { return m.ExtraFields() }
+
+func (m metadata) null() bool {
+ if _, ok := m.any.(metadataNull); ok {
+ return true
+ }
+
+ if msg, ok := m.any.(json.RawMessage); ok {
+ return string(msg) == "null"
+ }
+
+ return false
}
+type metadata struct{ any }
+type metadataNull struct{}
+type metadataExtraFields map[string]any
+
func (m *metadata) setMetadata(override any) {
if override == nil {
m.any = metadataNull{}
@@ -1,31 +0,0 @@
-package param
-
-import (
- "github.com/openai/openai-go/internal/encoding/json/sentinel"
-)
-
-// NullPtr returns a pointer to the zero value of the type T.
-// When used with [MarshalObject] or [MarshalUnion], it will be marshaled as null.
-//
-// It is unspecified behavior to mutate the value pointed to by the returned pointer.
-func NullPtr[T any]() *T {
- return sentinel.NullPtr[T]()
-}
-
-// IsNullPtr returns true if the pointer was created by [NullPtr].
-func IsNullPtr[T any](ptr *T) bool {
- return sentinel.IsNullPtr(ptr)
-}
-
-// NullSlice returns a non-nil slice with a length of 0.
-// When used with [MarshalObject] or [MarshalUnion], it will be marshaled as null.
-//
-// It is undefined behavior to mutate the slice returned by [NullSlice].
-func NullSlice[T any]() []T {
- return sentinel.NullSlice[T]()
-}
-
-// IsNullSlice returns true if the slice was created by [NullSlice].
-func IsNullSlice[T any](slice []T) bool {
- return sentinel.IsNullSlice(slice)
-}
@@ -1,56 +0,0 @@
-package resp
-
-// A Field contains metadata about a JSON field that was
-// unmarshalled from a response.
-//
-// To check if the field was unmarshalled successfully, use the [Field.IsPresent] method.
-//
-// Use the [Field.IsExplicitNull] method to check if the JSON value is "null".
-//
-// If the [Field.Raw] is the empty string, then the field was omitted.
-//
-// Otherwise, if the field was invalid and couldn't be marshalled successfully, [Field.IsPresent] will be false,
-// and [Field.Raw] will not be empty.
-type Field struct {
- status
- raw string
-}
-
-const (
- omitted status = iota
- null
- invalid
- valid
-)
-
-type status int8
-
-// IsPresent returns true if the field was unmarshalled successfully.
-// If IsPresent is false, the field was either omitted, the JSON value "null", or an unexpected type.
-func (j Field) IsPresent() bool { return j.status > invalid }
-
-// Returns true if the field is the JSON value "null".
-func (j Field) IsExplicitNull() bool { return j.status == null }
-
-// Returns the raw JSON value of the field.
-func (j Field) Raw() string {
- if j.status == omitted {
- return ""
- }
- return j.raw
-}
-
-func NewValidField(raw string) Field {
- if raw == "null" {
- return NewNullField()
- }
- return Field{raw: raw, status: valid}
-}
-
-func NewNullField() Field {
- return Field{status: null}
-}
-
-func NewInvalidField(raw string) Field {
- return Field{status: invalid, raw: raw}
-}
@@ -0,0 +1,88 @@
+package respjson
+
+// A Field provides metadata to indicate the presence of a value.
+//
+// Use [Field.Valid] to check if an optional value was null or omitted.
+//
+// A Field will always occur in the following structure, where it
+// mirrors the original field in it's parent struct:
+//
+// type ExampleObject struct {
+// Foo bool `json:"foo"`
+// Bar int `json:"bar"`
+// // ...
+//
+// // JSON provides metadata about the object.
+// JSON struct {
+// Foo Field
+// Bar Field
+// // ...
+// } `json:"-"`
+// }
+//
+// To differentiate a "nullish" value from the zero value,
+// use the [Field.Valid] method.
+//
+// if !example.JSON.Foo.Valid() {
+// println("Foo is null or omitted")
+// }
+//
+// if example.Foo {
+// println("Foo is true")
+// } else {
+// println("Foo is false")
+// }
+//
+// To differentiate if a field was omitted or the JSON value "null",
+// use the [Field.Raw] method.
+//
+// if example.JSON.Foo.Raw() == "null" {
+// println("Foo is null")
+// }
+//
+// if example.JSON.Foo.Raw() == "" {
+// println("Foo was omitted")
+// }
+//
+// Otherwise, if the field was invalid and couldn't be marshalled successfully,
+// [Field.Valid] will be false and [Field.Raw] will not be empty.
+type Field struct {
+ status
+ raw string
+}
+
+const (
+ omitted status = iota
+ null
+ invalid
+ valid
+)
+
+type status int8
+
+// Valid returns true if the parent field was set.
+// Valid returns false if the value doesn't exist, is JSON null, or
+// is an unexpected type.
+func (j Field) Valid() bool { return j.status > invalid }
+
+const Null string = "null"
+const Omitted string = ""
+
+// Returns the raw JSON value of the field.
+func (j Field) Raw() string {
+ if j.status == omitted {
+ return ""
+ }
+ return j.raw
+}
+
+func NewField(raw string) Field {
+ if raw == "null" {
+ return Field{status: null, raw: Null}
+ }
+ return Field{status: valid, raw: raw}
+}
+
+func NewInvalidField(raw string) Field {
+ return Field{status: invalid, raw: raw}
+}
@@ -31,8 +31,9 @@ func NewDecoder(res *http.Response) Decoder {
if t, ok := decoderTypes[contentType]; ok {
decoder = t(res.Body)
} else {
- scanner := bufio.NewScanner(res.Body)
- decoder = &eventStreamDecoder{rc: res.Body, scn: scanner}
+ scn := bufio.NewScanner(res.Body)
+ scn.Buffer(nil, bufio.MaxScanTokenSize<<4)
+ decoder = &eventStreamDecoder{rc: res.Body, scn: scn}
}
return decoder
}
@@ -162,16 +163,18 @@ func (s *Stream[T]) Next() bool {
continue
}
+ var nxt T
if s.decoder.Event().Type == "" || strings.HasPrefix(s.decoder.Event().Type, "response.") {
ep := gjson.GetBytes(s.decoder.Event().Data, "error")
if ep.Exists() {
s.err = fmt.Errorf("received error while streaming: %s", ep.String())
return false
}
- s.err = json.Unmarshal(s.decoder.Event().Data, &s.cur)
+ s.err = json.Unmarshal(s.decoder.Event().Data, &nxt)
if s.err != nil {
return false
}
+ s.cur = nxt
return true
} else {
ep := gjson.GetBytes(s.decoder.Event().Data, "error")
@@ -181,10 +184,11 @@ func (s *Stream[T]) Next() bool {
}
event := s.decoder.Event().Type
data := s.decoder.Event().Data
- s.err = json.Unmarshal([]byte(fmt.Sprintf(`{ "event": %q, "data": %s }`, event, data)), &s.cur)
+ s.err = json.Unmarshal([]byte(fmt.Sprintf(`{ "event": %q, "data": %s }`, event, data)), &nxt)
if s.err != nil {
return false
}
+ s.cur = nxt
return true
}
}
@@ -32,7 +32,7 @@ func (r *VectorStoreFileService) PollStatus(ctx context.Context, vectorStoreID s
opts = append(opts, mkPollingOptions(pollIntervalMs)...)
opts = append(opts, option.WithResponseInto(&raw))
for {
- file, err := r.Get(ctx, vectorStoreID, fileID, opts...)
+ file, err := r.Get(ctx, fileID, vectorStoreID, opts...)
if err != nil {
return nil, fmt.Errorf("vector store file poll: received %w", err)
}
@@ -67,7 +67,7 @@ func (r *VectorStoreFileBatchService) PollStatus(ctx context.Context, vectorStor
opts = append(opts, option.WithResponseInto(&raw))
opts = append(opts, mkPollingOptions(pollIntervalMs)...)
for {
- batch, err := r.Get(ctx, vectorStoreID, batchID, opts...)
+ batch, err := r.Get(ctx, batchID, vectorStoreID, opts...)
if err != nil {
return nil, fmt.Errorf("vector store file batch poll: received %w", err)
}
@@ -93,42 +93,3 @@ func (r *VectorStoreFileBatchService) PollStatus(ctx context.Context, vectorStor
}
}
}
-
-// PollStatus waits until a Run is no longer in an incomplete state and returns it.
-// Pass 0 as pollIntervalMs to use the default polling interval of 1 second.
-func (r *BetaThreadRunService) PollStatus(ctx context.Context, threadID string, runID string, pollIntervalMs int, opts ...option.RequestOption) (res *Run, err error) {
- var raw *http.Response
- opts = append(opts, mkPollingOptions(pollIntervalMs)...)
- opts = append(opts, option.WithResponseInto(&raw))
- for {
- run, err := r.Get(ctx, threadID, runID, opts...)
- if err != nil {
- return nil, fmt.Errorf("thread run poll: received %w", err)
- }
-
- switch run.Status {
- case RunStatusInProgress,
- RunStatusQueued:
- if pollIntervalMs <= 0 {
- pollIntervalMs = getPollInterval(raw)
- }
- time.Sleep(time.Duration(pollIntervalMs) * time.Millisecond)
- case RunStatusRequiresAction,
- RunStatusCancelled,
- RunStatusCompleted,
- RunStatusFailed,
- RunStatusExpired,
- RunStatusIncomplete:
- return run, nil
- default:
- return nil, fmt.Errorf("invalid thread run status during polling: received %s", run.Status)
- }
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- default:
- break
- }
- }
-}
@@ -5,7 +5,6 @@ package responses
import (
"github.com/openai/openai-go/internal/apierror"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
"github.com/openai/openai-go/shared"
)
@@ -20,6 +19,36 @@ type Error = apierror.Error
// This is an alias to an internal type.
type ChatModel = shared.ChatModel
+// Equals "gpt-4.1"
+const ChatModelGPT4_1 = shared.ChatModelGPT4_1
+
+// Equals "gpt-4.1-mini"
+const ChatModelGPT4_1Mini = shared.ChatModelGPT4_1Mini
+
+// Equals "gpt-4.1-nano"
+const ChatModelGPT4_1Nano = shared.ChatModelGPT4_1Nano
+
+// Equals "gpt-4.1-2025-04-14"
+const ChatModelGPT4_1_2025_04_14 = shared.ChatModelGPT4_1_2025_04_14
+
+// Equals "gpt-4.1-mini-2025-04-14"
+const ChatModelGPT4_1Mini2025_04_14 = shared.ChatModelGPT4_1Mini2025_04_14
+
+// Equals "gpt-4.1-nano-2025-04-14"
+const ChatModelGPT4_1Nano2025_04_14 = shared.ChatModelGPT4_1Nano2025_04_14
+
+// Equals "o4-mini"
+const ChatModelO4Mini = shared.ChatModelO4Mini
+
+// Equals "o4-mini-2025-04-16"
+const ChatModelO4Mini2025_04_16 = shared.ChatModelO4Mini2025_04_16
+
+// Equals "o3"
+const ChatModelO3 = shared.ChatModelO3
+
+// Equals "o3-2025-04-16"
+const ChatModelO3_2025_04_16 = shared.ChatModelO3_2025_04_16
+
// Equals "o3-mini"
const ChatModelO3Mini = shared.ChatModelO3Mini
@@ -65,6 +94,9 @@ const ChatModelGPT4oAudioPreview2024_10_01 = shared.ChatModelGPT4oAudioPreview20
// Equals "gpt-4o-audio-preview-2024-12-17"
const ChatModelGPT4oAudioPreview2024_12_17 = shared.ChatModelGPT4oAudioPreview2024_12_17
+// Equals "gpt-4o-audio-preview-2025-06-03"
+const ChatModelGPT4oAudioPreview2025_06_03 = shared.ChatModelGPT4oAudioPreview2025_06_03
+
// Equals "gpt-4o-mini-audio-preview"
const ChatModelGPT4oMiniAudioPreview = shared.ChatModelGPT4oMiniAudioPreview
@@ -86,6 +118,9 @@ const ChatModelGPT4oMiniSearchPreview2025_03_11 = shared.ChatModelGPT4oMiniSearc
// Equals "chatgpt-4o-latest"
const ChatModelChatgpt4oLatest = shared.ChatModelChatgpt4oLatest
+// Equals "codex-mini-latest"
+const ChatModelCodexMiniLatest = shared.ChatModelCodexMiniLatest
+
// Equals "gpt-4o-mini"
const ChatModelGPT4oMini = shared.ChatModelGPT4oMini
@@ -254,16 +289,6 @@ type FunctionParameters = shared.FunctionParameters
// This is an alias to an internal type.
type Metadata = shared.Metadata
-// Set of 16 key-value pairs that can be attached to an object. This can be useful
-// for storing additional information about the object in a structured format, and
-// querying for objects via API or the dashboard.
-//
-// Keys are strings with a maximum length of 64 characters. Values are strings with
-// a maximum length of 512 characters.
-//
-// This is an alias to an internal type.
-type MetadataParam = shared.MetadataParam
-
// **o-series models only**
//
// Configuration options for
@@ -272,21 +297,40 @@ type MetadataParam = shared.MetadataParam
// This is an alias to an internal type.
type Reasoning = shared.Reasoning
-// **computer_use_preview only**
+// **Deprecated:** use `summary` instead.
//
// A summary of the reasoning performed by the model. This can be useful for
-// debugging and understanding the model's reasoning process. One of `concise` or
-// `detailed`.
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
//
// This is an alias to an internal type.
type ReasoningGenerateSummary = shared.ReasoningGenerateSummary
+// Equals "auto"
+const ReasoningGenerateSummaryAuto = shared.ReasoningGenerateSummaryAuto
+
// Equals "concise"
const ReasoningGenerateSummaryConcise = shared.ReasoningGenerateSummaryConcise
// Equals "detailed"
const ReasoningGenerateSummaryDetailed = shared.ReasoningGenerateSummaryDetailed
+// A summary of the reasoning performed by the model. This can be useful for
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
+//
+// This is an alias to an internal type.
+type ReasoningSummary = shared.ReasoningSummary
+
+// Equals "auto"
+const ReasoningSummaryAuto = shared.ReasoningSummaryAuto
+
+// Equals "concise"
+const ReasoningSummaryConcise = shared.ReasoningSummaryConcise
+
+// Equals "detailed"
+const ReasoningSummaryDetailed = shared.ReasoningSummaryDetailed
+
// **o-series models only**
//
// Configuration options for
@@ -371,18 +415,26 @@ const ResponsesModelO1Pro = shared.ResponsesModelO1Pro
// Equals "o1-pro-2025-03-19"
const ResponsesModelO1Pro2025_03_19 = shared.ResponsesModelO1Pro2025_03_19
+// Equals "o3-pro"
+const ResponsesModelO3Pro = shared.ResponsesModelO3Pro
+
+// Equals "o3-pro-2025-06-10"
+const ResponsesModelO3Pro2025_06_10 = shared.ResponsesModelO3Pro2025_06_10
+
+// Equals "o3-deep-research"
+const ResponsesModelO3DeepResearch = shared.ResponsesModelO3DeepResearch
+
+// Equals "o3-deep-research-2025-06-26"
+const ResponsesModelO3DeepResearch2025_06_26 = shared.ResponsesModelO3DeepResearch2025_06_26
+
+// Equals "o4-mini-deep-research"
+const ResponsesModelO4MiniDeepResearch = shared.ResponsesModelO4MiniDeepResearch
+
+// Equals "o4-mini-deep-research-2025-06-26"
+const ResponsesModelO4MiniDeepResearch2025_06_26 = shared.ResponsesModelO4MiniDeepResearch2025_06_26
+
// Equals "computer-use-preview"
const ResponsesModelComputerUsePreview = shared.ResponsesModelComputerUsePreview
// Equals "computer-use-preview-2025-03-11"
const ResponsesModelComputerUsePreview2025_03_11 = shared.ResponsesModelComputerUsePreview2025_03_11
-
-func toParam[T comparable](value T, meta resp.Field) param.Opt[T] {
- if meta.IsPresent() {
- return param.NewOpt(value)
- }
- if meta.IsExplicitNull() {
- return param.NullOpt[T]()
- }
- return param.Opt[T]{}
-}
@@ -15,7 +15,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -77,15 +77,14 @@ type ResponseItemList struct {
LastID string `json:"last_id,required"`
// The type of object returned, must be `list`.
Object constant.List `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Data resp.Field
- FirstID resp.Field
- HasMore resp.Field
- LastID resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ Data respjson.Field
+ FirstID respjson.Field
+ HasMore respjson.Field
+ LastID respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -104,7 +103,10 @@ type InputItemListParams struct {
// A limit on the number of objects to be returned. Limit can range between 1 and
// 100, and the default is 20.
Limit param.Opt[int64] `query:"limit,omitzero" json:"-"`
- // The order to return the input items in. Default is `asc`.
+ // Additional fields to include in the response. See the `include` parameter for
+ // Response creation above for more information.
+ Include []ResponseIncludable `query:"include,omitzero" json:"-"`
+ // The order to return the input items in. Default is `desc`.
//
// - `asc`: Return the input items in ascending order.
// - `desc`: Return the input items in descending order.
@@ -114,19 +116,15 @@ type InputItemListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f InputItemListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [InputItemListParams]'s query parameters as `url.Values`.
-func (r InputItemListParams) URLQuery() (v url.Values) {
+func (r InputItemListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
})
}
-// The order to return the input items in. Default is `asc`.
+// The order to return the input items in. Default is `desc`.
//
// - `asc`: Return the input items in ascending order.
// - `desc`: Return the input items in descending order.
@@ -9,19 +9,18 @@ import (
"fmt"
"net/http"
"net/url"
- "reflect"
"strings"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/apiquery"
+ "github.com/openai/openai-go/internal/paramutil"
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/packages/ssestream"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// ResponseService contains methods and other services that help with interacting
@@ -98,6 +97,23 @@ func (r *ResponseService) Get(ctx context.Context, responseID string, query Resp
return
}
+// Retrieves a model response with the given ID.
+func (r *ResponseService) GetStreaming(ctx context.Context, responseID string, query ResponseGetParams, opts ...option.RequestOption) (stream *ssestream.Stream[ResponseStreamEventUnion]) {
+ var (
+ raw *http.Response
+ err error
+ )
+ opts = append(r.Options[:], opts...)
+ opts = append([]option.RequestOption{option.WithJSONSet("stream", true)}, opts...)
+ if responseID == "" {
+ err = errors.New("missing required response_id parameter")
+ return
+ }
+ path := fmt.Sprintf("responses/%s", responseID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &raw, opts...)
+ return ssestream.NewStream[ResponseStreamEventUnion](ssestream.NewDecoder(raw), err)
+}
+
// Deletes a model response with the given ID.
func (r *ResponseService) Delete(ctx context.Context, responseID string, opts ...option.RequestOption) (err error) {
opts = append(r.Options[:], opts...)
@@ -111,27 +127,40 @@ func (r *ResponseService) Delete(ctx context.Context, responseID string, opts ..
return
}
+// Cancels a model response with the given ID. Only responses created with the
+// `background` parameter set to `true` can be cancelled.
+// [Learn more](https://platform.openai.com/docs/guides/background).
+func (r *ResponseService) Cancel(ctx context.Context, responseID string, opts ...option.RequestOption) (res *Response, err error) {
+ opts = append(r.Options[:], opts...)
+ if responseID == "" {
+ err = errors.New("missing required response_id parameter")
+ return
+ }
+ path := fmt.Sprintf("responses/%s/cancel", responseID)
+ err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, nil, &res, opts...)
+ return
+}
+
// A tool that controls a virtual computer. Learn more about the
// [computer tool](https://platform.openai.com/docs/guides/tools-computer-use).
type ComputerTool struct {
// The height of the computer display.
- DisplayHeight float64 `json:"display_height,required"`
+ DisplayHeight int64 `json:"display_height,required"`
// The width of the computer display.
- DisplayWidth float64 `json:"display_width,required"`
+ DisplayWidth int64 `json:"display_width,required"`
// The type of computer environment to control.
//
- // Any of "mac", "windows", "ubuntu", "browser".
+ // Any of "windows", "mac", "linux", "ubuntu", "browser".
Environment ComputerToolEnvironment `json:"environment,required"`
// The type of the computer use tool. Always `computer_use_preview`.
Type constant.ComputerUsePreview `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- DisplayHeight resp.Field
- DisplayWidth resp.Field
- Environment resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ DisplayHeight respjson.Field
+ DisplayWidth respjson.Field
+ Environment respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -146,17 +175,18 @@ func (r *ComputerTool) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ComputerToolParam.IsOverridden()
+// ComputerToolParam.Overrides()
func (r ComputerTool) ToParam() ComputerToolParam {
- return param.OverrideObj[ComputerToolParam](r.RawJSON())
+ return param.Override[ComputerToolParam](json.RawMessage(r.RawJSON()))
}
// The type of computer environment to control.
type ComputerToolEnvironment string
const (
- ComputerToolEnvironmentMac ComputerToolEnvironment = "mac"
ComputerToolEnvironmentWindows ComputerToolEnvironment = "windows"
+ ComputerToolEnvironmentMac ComputerToolEnvironment = "mac"
+ ComputerToolEnvironmentLinux ComputerToolEnvironment = "linux"
ComputerToolEnvironmentUbuntu ComputerToolEnvironment = "ubuntu"
ComputerToolEnvironmentBrowser ComputerToolEnvironment = "browser"
)
@@ -167,12 +197,12 @@ const (
// The properties DisplayHeight, DisplayWidth, Environment, Type are required.
type ComputerToolParam struct {
// The height of the computer display.
- DisplayHeight float64 `json:"display_height,required"`
+ DisplayHeight int64 `json:"display_height,required"`
// The width of the computer display.
- DisplayWidth float64 `json:"display_width,required"`
+ DisplayWidth int64 `json:"display_width,required"`
// The type of computer environment to control.
//
- // Any of "mac", "windows", "ubuntu", "browser".
+ // Any of "windows", "mac", "linux", "ubuntu", "browser".
Environment ComputerToolEnvironment `json:"environment,omitzero,required"`
// The type of the computer use tool. Always `computer_use_preview`.
//
@@ -182,13 +212,111 @@ type ComputerToolParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ComputerToolParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ComputerToolParam) MarshalJSON() (data []byte, err error) {
type shadow ComputerToolParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ComputerToolParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// A message input to the model with a role indicating instruction following
+// hierarchy. Instructions given with the `developer` or `system` role take
+// precedence over instructions given with the `user` role. Messages with the
+// `assistant` role are presumed to have been generated by the model in previous
+// interactions.
+type EasyInputMessage struct {
+ // Text, image, or audio input to the model, used to generate a response. Can also
+ // contain previous assistant responses.
+ Content EasyInputMessageContentUnion `json:"content,required"`
+ // The role of the message input. One of `user`, `assistant`, `system`, or
+ // `developer`.
+ //
+ // Any of "user", "assistant", "system", "developer".
+ Role EasyInputMessageRole `json:"role,required"`
+ // The type of the message input. Always `message`.
+ //
+ // Any of "message".
+ Type EasyInputMessageType `json:"type"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Content respjson.Field
+ Role respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EasyInputMessage) RawJSON() string { return r.JSON.raw }
+func (r *EasyInputMessage) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// ToParam converts this EasyInputMessage to a EasyInputMessageParam.
+//
+// Warning: the fields of the param type will not be present. ToParam should only
+// be used at the last possible moment before sending a request. Test for this with
+// EasyInputMessageParam.Overrides()
+func (r EasyInputMessage) ToParam() EasyInputMessageParam {
+ return param.Override[EasyInputMessageParam](json.RawMessage(r.RawJSON()))
+}
+
+// EasyInputMessageContentUnion contains all possible properties and values from
+// [string], [ResponseInputMessageContentList].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString OfInputItemContentList]
+type EasyInputMessageContentUnion struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ // This field will be present if the value is a [ResponseInputMessageContentList]
+ // instead of an object.
+ OfInputItemContentList ResponseInputMessageContentList `json:",inline"`
+ JSON struct {
+ OfString respjson.Field
+ OfInputItemContentList respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u EasyInputMessageContentUnion) AsString() (v string) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u EasyInputMessageContentUnion) AsInputItemContentList() (v ResponseInputMessageContentList) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u EasyInputMessageContentUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *EasyInputMessageContentUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The role of the message input. One of `user`, `assistant`, `system`, or
+// `developer`.
+type EasyInputMessageRole string
+
+const (
+ EasyInputMessageRoleUser EasyInputMessageRole = "user"
+ EasyInputMessageRoleAssistant EasyInputMessageRole = "assistant"
+ EasyInputMessageRoleSystem EasyInputMessageRole = "system"
+ EasyInputMessageRoleDeveloper EasyInputMessageRole = "developer"
+)
+
+// The type of the message input. Always `message`.
+type EasyInputMessageType string
+
+const (
+ EasyInputMessageTypeMessage EasyInputMessageType = "message"
+)
// A message input to the model with a role indicating instruction following
// hierarchy. Instructions given with the `developer` or `system` role take
@@ -213,13 +341,13 @@ type EasyInputMessageParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f EasyInputMessageParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r EasyInputMessageParam) MarshalJSON() (data []byte, err error) {
type shadow EasyInputMessageParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *EasyInputMessageParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -230,13 +358,11 @@ type EasyInputMessageContentUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u EasyInputMessageContentUnionParam) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u EasyInputMessageContentUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[EasyInputMessageContentUnionParam](u.OfString, u.OfInputItemContentList)
+ return param.MarshalUnion(u, u.OfString, u.OfInputItemContentList)
+}
+func (u *EasyInputMessageContentUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *EasyInputMessageContentUnionParam) asAny() any {
@@ -248,24 +374,6 @@ func (u *EasyInputMessageContentUnionParam) asAny() any {
return nil
}
-// The role of the message input. One of `user`, `assistant`, `system`, or
-// `developer`.
-type EasyInputMessageRole string
-
-const (
- EasyInputMessageRoleUser EasyInputMessageRole = "user"
- EasyInputMessageRoleAssistant EasyInputMessageRole = "assistant"
- EasyInputMessageRoleSystem EasyInputMessageRole = "system"
- EasyInputMessageRoleDeveloper EasyInputMessageRole = "developer"
-)
-
-// The type of the message input. Always `message`.
-type EasyInputMessageType string
-
-const (
- EasyInputMessageTypeMessage EasyInputMessageType = "message"
-)
-
// A tool that searches for relevant content from uploaded files. Learn more about
// the
// [file search tool](https://platform.openai.com/docs/guides/tools-file-search).
@@ -274,22 +382,21 @@ type FileSearchTool struct {
Type constant.FileSearch `json:"type,required"`
// The IDs of the vector stores to search.
VectorStoreIDs []string `json:"vector_store_ids,required"`
- // A filter to apply based on file attributes.
- Filters FileSearchToolFiltersUnion `json:"filters"`
+ // A filter to apply.
+ Filters FileSearchToolFiltersUnion `json:"filters,nullable"`
// The maximum number of results to return. This number should be between 1 and 50
// inclusive.
MaxNumResults int64 `json:"max_num_results"`
// Ranking options for search.
RankingOptions FileSearchToolRankingOptions `json:"ranking_options"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- Type resp.Field
- VectorStoreIDs resp.Field
- Filters resp.Field
- MaxNumResults resp.Field
- RankingOptions resp.Field
- ExtraFields map[string]resp.Field
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Type respjson.Field
+ VectorStoreIDs respjson.Field
+ Filters respjson.Field
+ MaxNumResults respjson.Field
+ RankingOptions respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -304,9 +411,9 @@ func (r *FileSearchTool) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// FileSearchToolParam.IsOverridden()
+// FileSearchToolParam.Overrides()
func (r FileSearchTool) ToParam() FileSearchToolParam {
- return param.OverrideObj[FileSearchToolParam](r.RawJSON())
+ return param.Override[FileSearchToolParam](json.RawMessage(r.RawJSON()))
}
// FileSearchToolFiltersUnion contains all possible properties and values from
@@ -322,10 +429,10 @@ type FileSearchToolFiltersUnion struct {
// This field is from variant [shared.CompoundFilter].
Filters []shared.ComparisonFilter `json:"filters"`
JSON struct {
- Key resp.Field
- Type resp.Field
- Value resp.Field
- Filters resp.Field
+ Key respjson.Field
+ Type respjson.Field
+ Value respjson.Field
+ Filters respjson.Field
raw string
} `json:"-"`
}
@@ -357,12 +464,11 @@ type FileSearchToolRankingOptions struct {
// closer to 1 will attempt to return only the most relevant results, but may
// return fewer results.
ScoreThreshold float64 `json:"score_threshold"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Ranker resp.Field
- ScoreThreshold resp.Field
- ExtraFields map[string]resp.Field
+ Ranker respjson.Field
+ ScoreThreshold respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -384,7 +490,7 @@ type FileSearchToolParam struct {
// The maximum number of results to return. This number should be between 1 and 50
// inclusive.
MaxNumResults param.Opt[int64] `json:"max_num_results,omitzero"`
- // A filter to apply based on file attributes.
+ // A filter to apply.
Filters FileSearchToolFiltersUnionParam `json:"filters,omitzero"`
// Ranking options for search.
RankingOptions FileSearchToolRankingOptionsParam `json:"ranking_options,omitzero"`
@@ -395,13 +501,13 @@ type FileSearchToolParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FileSearchToolParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r FileSearchToolParam) MarshalJSON() (data []byte, err error) {
type shadow FileSearchToolParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FileSearchToolParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -412,11 +518,11 @@ type FileSearchToolFiltersUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FileSearchToolFiltersUnionParam) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u FileSearchToolFiltersUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FileSearchToolFiltersUnionParam](u.OfComparisonFilter, u.OfCompoundFilter)
+ return param.MarshalUnion(u, u.OfComparisonFilter, u.OfCompoundFilter)
+}
+func (u *FileSearchToolFiltersUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *FileSearchToolFiltersUnionParam) asAny() any {
@@ -475,19 +581,17 @@ type FileSearchToolRankingOptionsParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FileSearchToolRankingOptionsParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r FileSearchToolRankingOptionsParam) MarshalJSON() (data []byte, err error) {
type shadow FileSearchToolRankingOptionsParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FileSearchToolRankingOptionsParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[FileSearchToolRankingOptionsParam](
- "Ranker", false, "auto", "default-2024-11-15",
+ "ranker", "auto", "default-2024-11-15",
)
}
@@ -498,7 +602,7 @@ type FunctionTool struct {
// The name of the function to call.
Name string `json:"name,required"`
// A JSON schema object describing the parameters of the function.
- Parameters map[string]interface{} `json:"parameters,required"`
+ Parameters map[string]any `json:"parameters,required"`
// Whether to enforce strict parameter validation. Default `true`.
Strict bool `json:"strict,required"`
// The type of the function tool. Always `function`.
@@ -506,15 +610,14 @@ type FunctionTool struct {
// A description of the function. Used by the model to determine whether or not to
// call the function.
Description string `json:"description,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- Name resp.Field
- Parameters resp.Field
- Strict resp.Field
- Type resp.Field
- Description resp.Field
- ExtraFields map[string]resp.Field
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ Name respjson.Field
+ Parameters respjson.Field
+ Strict respjson.Field
+ Type respjson.Field
+ Description respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -529,9 +632,9 @@ func (r *FunctionTool) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// FunctionToolParam.IsOverridden()
+// FunctionToolParam.Overrides()
func (r FunctionTool) ToParam() FunctionToolParam {
- return param.OverrideObj[FunctionToolParam](r.RawJSON())
+ return param.Override[FunctionToolParam](json.RawMessage(r.RawJSON()))
}
// Defines a function in your own code the model can choose to call. Learn more
@@ -540,12 +643,12 @@ func (r FunctionTool) ToParam() FunctionToolParam {
//
// The properties Name, Parameters, Strict, Type are required.
type FunctionToolParam struct {
+ // Whether to enforce strict parameter validation. Default `true`.
+ Strict param.Opt[bool] `json:"strict,omitzero,required"`
+ // A JSON schema object describing the parameters of the function.
+ Parameters map[string]any `json:"parameters,omitzero,required"`
// The name of the function to call.
Name string `json:"name,required"`
- // A JSON schema object describing the parameters of the function.
- Parameters map[string]interface{} `json:"parameters,omitzero,required"`
- // Whether to enforce strict parameter validation. Default `true`.
- Strict bool `json:"strict,required"`
// A description of the function. Used by the model to determine whether or not to
// call the function.
Description param.Opt[string] `json:"description,omitzero"`
@@ -556,13 +659,13 @@ type FunctionToolParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FunctionToolParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r FunctionToolParam) MarshalJSON() (data []byte, err error) {
type shadow FunctionToolParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FunctionToolParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type Response struct {
// Unique identifier for this Response.
@@ -573,13 +676,12 @@ type Response struct {
Error ResponseError `json:"error,required"`
// Details about why the response is incomplete.
IncompleteDetails ResponseIncompleteDetails `json:"incomplete_details,required"`
- // Inserts a system (or developer) message as the first item in the model's
- // context.
+ // A system (or developer) message inserted into the model's context.
//
// When using along with `previous_response_id`, the instructions from a previous
- // response will be not be carried over to the next response. This makes it simple
- // to swap out system (or developer) messages in new responses.
- Instructions string `json:"instructions,required"`
+ // response will not be carried over to the next response. This makes it simple to
+ // swap out system (or developer) messages in new responses.
+ Instructions ResponseInstructionsUnion `json:"instructions,required"`
// Set of 16 key-value pairs that can be attached to an object. This can be useful
// for storing additional information about the object in a structured format, and
// querying for objects via API or the dashboard.
@@ -587,7 +689,7 @@ type Response struct {
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
Metadata shared.Metadata `json:"metadata,required"`
- // Model ID used to generate the response, like `gpt-4o` or `o1`. OpenAI offers a
+ // Model ID used to generate the response, like `gpt-4o` or `o3`. OpenAI offers a
// wide range of models with different capabilities, performance characteristics,
// and price points. Refer to the
// [model guide](https://platform.openai.com/docs/models) to browse and compare
@@ -635,23 +737,55 @@ type Response struct {
//
// We generally recommend altering this or `temperature` but not both.
TopP float64 `json:"top_p,required"`
+ // Whether to run the model response in the background.
+ // [Learn more](https://platform.openai.com/docs/guides/background).
+ Background bool `json:"background,nullable"`
// An upper bound for the number of tokens that can be generated for a response,
// including visible output tokens and
// [reasoning tokens](https://platform.openai.com/docs/guides/reasoning).
MaxOutputTokens int64 `json:"max_output_tokens,nullable"`
+ // The maximum number of total calls to built-in tools that can be processed in a
+ // response. This maximum number applies across all built-in tool calls, not per
+ // individual tool. Any further attempts to call a tool by the model will be
+ // ignored.
+ MaxToolCalls int64 `json:"max_tool_calls,nullable"`
// The unique ID of the previous response to the model. Use this to create
// multi-turn conversations. Learn more about
// [conversation state](https://platform.openai.com/docs/guides/conversation-state).
PreviousResponseID string `json:"previous_response_id,nullable"`
+ // Reference to a prompt template and its variables.
+ // [Learn more](https://platform.openai.com/docs/guides/text?api-mode=responses#reusable-prompts).
+ Prompt ResponsePrompt `json:"prompt,nullable"`
// **o-series models only**
//
// Configuration options for
// [reasoning models](https://platform.openai.com/docs/guides/reasoning).
Reasoning shared.Reasoning `json:"reasoning,nullable"`
+ // Specifies the processing type used for serving the request.
+ //
+ // - If set to 'auto', then the request will be processed with the service tier
+ // configured in the Project settings. Unless otherwise configured, the Project
+ // will use 'default'.
+ // - If set to 'default', then the requset will be processed with the standard
+ // pricing and performance for the selected model.
+ // - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
+ // 'priority', then the request will be processed with the corresponding service
+ // tier. [Contact sales](https://openai.com/contact-sales) to learn more about
+ // Priority processing.
+ // - When not set, the default behavior is 'auto'.
+ //
+ // When the `service_tier` parameter is set, the response body will include the
+ // `service_tier` value based on the processing mode actually used to serve the
+ // request. This response value may be different from the value set in the
+ // parameter.
+ //
+ // Any of "auto", "default", "flex", "scale", "priority".
+ ServiceTier ResponseServiceTier `json:"service_tier,nullable"`
// The status of the response generation. One of `completed`, `failed`,
- // `in_progress`, or `incomplete`.
+ // `in_progress`, `cancelled`, `queued`, or `incomplete`.
//
- // Any of "completed", "failed", "in_progress", "incomplete".
+ // Any of "completed", "failed", "in_progress", "cancelled", "queued",
+ // "incomplete".
Status ResponseStatus `json:"status"`
// Configuration options for a text response from the model. Can be plain text or
// structured JSON data. Learn more:
@@ -659,6 +793,9 @@ type Response struct {
// - [Text inputs and outputs](https://platform.openai.com/docs/guides/text)
// - [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs)
Text ResponseTextConfig `json:"text"`
+ // An integer between 0 and 20 specifying the number of most likely tokens to
+ // return at each token position, each with an associated log probability.
+ TopLogprobs int64 `json:"top_logprobs,nullable"`
// The truncation strategy to use for the model response.
//
// - `auto`: If the context of this response and previous ones exceeds the model's
@@ -672,36 +809,40 @@ type Response struct {
// Represents token usage details including input tokens, output tokens, a
// breakdown of output tokens, and the total tokens used.
Usage ResponseUsage `json:"usage"`
- // A unique identifier representing your end-user, which can help OpenAI to monitor
- // and detect abuse.
+ // A stable identifier for your end-users. Used to boost cache hit rates by better
+ // bucketing similar requests and to help OpenAI detect and prevent abuse.
// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
User string `json:"user"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
- JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Error resp.Field
- IncompleteDetails resp.Field
- Instructions resp.Field
- Metadata resp.Field
- Model resp.Field
- Object resp.Field
- Output resp.Field
- ParallelToolCalls resp.Field
- Temperature resp.Field
- ToolChoice resp.Field
- Tools resp.Field
- TopP resp.Field
- MaxOutputTokens resp.Field
- PreviousResponseID resp.Field
- Reasoning resp.Field
- Status resp.Field
- Text resp.Field
- Truncation resp.Field
- Usage resp.Field
- User resp.Field
- ExtraFields map[string]resp.Field
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Error respjson.Field
+ IncompleteDetails respjson.Field
+ Instructions respjson.Field
+ Metadata respjson.Field
+ Model respjson.Field
+ Object respjson.Field
+ Output respjson.Field
+ ParallelToolCalls respjson.Field
+ Temperature respjson.Field
+ ToolChoice respjson.Field
+ Tools respjson.Field
+ TopP respjson.Field
+ Background respjson.Field
+ MaxOutputTokens respjson.Field
+ MaxToolCalls respjson.Field
+ PreviousResponseID respjson.Field
+ Prompt respjson.Field
+ Reasoning respjson.Field
+ ServiceTier respjson.Field
+ Status respjson.Field
+ Text respjson.Field
+ TopLogprobs respjson.Field
+ Truncation respjson.Field
+ Usage respjson.Field
+ User respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -730,11 +871,10 @@ type ResponseIncompleteDetails struct {
//
// Any of "max_output_tokens", "content_filter".
Reason string `json:"reason"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Reason resp.Field
- ExtraFields map[string]resp.Field
+ Reason respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -745,8 +885,45 @@ func (r *ResponseIncompleteDetails) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+// ResponseInstructionsUnion contains all possible properties and values from
+// [string], [[]ResponseInputItemUnion].
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+//
+// If the underlying value is not a json object, one of the following properties
+// will be valid: OfString OfInputItemList]
+type ResponseInstructionsUnion struct {
+ // This field will be present if the value is a [string] instead of an object.
+ OfString string `json:",inline"`
+ // This field will be present if the value is a [[]ResponseInputItemUnion] instead
+ // of an object.
+ OfInputItemList []ResponseInputItemUnion `json:",inline"`
+ JSON struct {
+ OfString respjson.Field
+ OfInputItemList respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (u ResponseInstructionsUnion) AsString() (v string) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u ResponseInstructionsUnion) AsInputItemList() (v []ResponseInputItemUnion) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u ResponseInstructionsUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *ResponseInstructionsUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
// ResponseToolChoiceUnion contains all possible properties and values from
-// [ToolChoiceOptions], [ToolChoiceTypes], [ToolChoiceFunction].
+// [ToolChoiceOptions], [ToolChoiceTypes], [ToolChoiceFunction], [ToolChoiceMcp].
//
// Use the methods beginning with 'As' to cast the union to one of its variants.
//
@@ -757,12 +934,14 @@ type ResponseToolChoiceUnion struct {
// object.
OfToolChoiceMode ToolChoiceOptions `json:",inline"`
Type string `json:"type"`
- // This field is from variant [ToolChoiceFunction].
- Name string `json:"name"`
- JSON struct {
- OfToolChoiceMode resp.Field
- Type resp.Field
- Name resp.Field
+ Name string `json:"name"`
+ // This field is from variant [ToolChoiceMcp].
+ ServerLabel string `json:"server_label"`
+ JSON struct {
+ OfToolChoiceMode respjson.Field
+ Type respjson.Field
+ Name respjson.Field
+ ServerLabel respjson.Field
raw string
} `json:"-"`
}
@@ -782,6 +961,11 @@ func (u ResponseToolChoiceUnion) AsFunctionTool() (v ToolChoiceFunction) {
return
}
+func (u ResponseToolChoiceUnion) AsMcpTool() (v ToolChoiceMcp) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
// Returns the unmodified JSON received from the API
func (u ResponseToolChoiceUnion) RawJSON() string { return u.JSON.raw }
@@ -789,6 +973,33 @@ func (r *ResponseToolChoiceUnion) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+// Specifies the processing type used for serving the request.
+//
+// - If set to 'auto', then the request will be processed with the service tier
+// configured in the Project settings. Unless otherwise configured, the Project
+// will use 'default'.
+// - If set to 'default', then the requset will be processed with the standard
+// pricing and performance for the selected model.
+// - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or
+// 'priority', then the request will be processed with the corresponding service
+// tier. [Contact sales](https://openai.com/contact-sales) to learn more about
+// Priority processing.
+// - When not set, the default behavior is 'auto'.
+//
+// When the `service_tier` parameter is set, the response body will include the
+// `service_tier` value based on the processing mode actually used to serve the
+// request. This response value may be different from the value set in the
+// parameter.
+type ResponseServiceTier string
+
+const (
+ ResponseServiceTierAuto ResponseServiceTier = "auto"
+ ResponseServiceTierDefault ResponseServiceTier = "default"
+ ResponseServiceTierFlex ResponseServiceTier = "flex"
+ ResponseServiceTierScale ResponseServiceTier = "scale"
+ ResponseServiceTierPriority ResponseServiceTier = "priority"
+)
+
// The truncation strategy to use for the model response.
//
// - `auto`: If the context of this response and previous ones exceeds the model's
@@ -807,15 +1018,17 @@ const (
type ResponseAudioDeltaEvent struct {
// A chunk of Base64 encoded response audio bytes.
Delta string `json:"delta,required"`
+ // A sequence number for this chunk of the stream response.
+ SequenceNumber int64 `json:"sequence_number,required"`
// The type of the event. Always `response.audio.delta`.
Type constant.ResponseAudioDelta `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Delta resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ Delta respjson.Field
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -827,14 +1040,16 @@ func (r *ResponseAudioDeltaEvent) UnmarshalJSON(data []byte) error {
// Emitted when the audio response is complete.
type ResponseAudioDoneEvent struct {
+ // The sequence number of the delta.
+ SequenceNumber int64 `json:"sequence_number,required"`
// The type of the event. Always `response.audio.done`.
Type constant.ResponseAudioDone `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -848,15 +1063,17 @@ func (r *ResponseAudioDoneEvent) UnmarshalJSON(data []byte) error {
type ResponseAudioTranscriptDeltaEvent struct {
// The partial transcript of the audio response.
Delta string `json:"delta,required"`
+ // The sequence number of this event.
+ SequenceNumber int64 `json:"sequence_number,required"`
// The type of the event. Always `response.audio.transcript.delta`.
Type constant.ResponseAudioTranscriptDelta `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Delta resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ Delta respjson.Field
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -868,14 +1085,16 @@ func (r *ResponseAudioTranscriptDeltaEvent) UnmarshalJSON(data []byte) error {
// Emitted when the full audio transcript is completed.
type ResponseAudioTranscriptDoneEvent struct {
+ // The sequence number of this event.
+ SequenceNumber int64 `json:"sequence_number,required"`
// The type of the event. Always `response.audio.transcript.done`.
Type constant.ResponseAudioTranscriptDone `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -885,22 +1104,28 @@ func (r *ResponseAudioTranscriptDoneEvent) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
-// Emitted when a partial code snippet is added by the code interpreter.
+// Emitted when a partial code snippet is streamed by the code interpreter.
type ResponseCodeInterpreterCallCodeDeltaEvent struct {
- // The partial code snippet added by the code interpreter.
+ // The partial code snippet being streamed by the code interpreter.
Delta string `json:"delta,required"`
- // The index of the output item that the code interpreter call is in progress.
+ // The unique identifier of the code interpreter tool call item.
+ ItemID string `json:"item_id,required"`
+ // The index of the output item in the response for which the code is being
+ // streamed.
OutputIndex int64 `json:"output_index,required"`
- // The type of the event. Always `response.code_interpreter_call.code.delta`.
+ // The sequence number of this event, used to order streaming events.
+ SequenceNumber int64 `json:"sequence_number,required"`
+ // The type of the event. Always `response.code_interpreter_call_code.delta`.
Type constant.ResponseCodeInterpreterCallCodeDelta `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Delta resp.Field
- OutputIndex resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ Delta respjson.Field
+ ItemID respjson.Field
+ OutputIndex respjson.Field
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -910,22 +1135,27 @@ func (r *ResponseCodeInterpreterCallCodeDeltaEvent) UnmarshalJSON(data []byte) e
return apijson.UnmarshalRoot(data, r)
}
-// Emitted when code snippet output is finalized by the code interpreter.
+// Emitted when the code snippet is finalized by the code interpreter.
type ResponseCodeInterpreterCallCodeDoneEvent struct {
// The final code snippet output by the code interpreter.
Code string `json:"code,required"`
- // The index of the output item that the code interpreter call is in progress.
+ // The unique identifier of the code interpreter tool call item.
+ ItemID string `json:"item_id,required"`
+ // The index of the output item in the response for which the code is finalized.
OutputIndex int64 `json:"output_index,required"`
- // The type of the event. Always `response.code_interpreter_call.code.done`.
+ // The sequence number of this event, used to order streaming events.
+ SequenceNumber int64 `json:"sequence_number,required"`
+ // The type of the event. Always `response.code_interpreter_call_code.done`.
Type constant.ResponseCodeInterpreterCallCodeDone `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- OutputIndex resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ Code respjson.Field
+ ItemID respjson.Field
+ OutputIndex respjson.Field
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -937,20 +1167,23 @@ func (r *ResponseCodeInterpreterCallCodeDoneEvent) UnmarshalJSON(data []byte) er
// Emitted when the code interpreter call is completed.
type ResponseCodeInterpreterCallCompletedEvent struct {
- // A tool call to run code.
- CodeInterpreterCall ResponseCodeInterpreterToolCall `json:"code_interpreter_call,required"`
- // The index of the output item that the code interpreter call is in progress.
+ // The unique identifier of the code interpreter tool call item.
+ ItemID string `json:"item_id,required"`
+ // The index of the output item in the response for which the code interpreter call
+ // is completed.
OutputIndex int64 `json:"output_index,required"`
+ // The sequence number of this event, used to order streaming events.
+ SequenceNumber int64 `json:"sequence_number,required"`
// The type of the event. Always `response.code_interpreter_call.completed`.
Type constant.ResponseCodeInterpreterCallCompleted `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- CodeInterpreterCall resp.Field
- OutputIndex resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
- raw string
+ ItemID respjson.Field
+ OutputIndex respjson.Field
+ SequenceNumber respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
} `json:"-"`
}
@@ -23,54 +23,82 @@ type Assistant string // Always "assistant"
type AssistantDeleted string // Always "assistant.deleted"
type Auto string // Always "auto"
type Batch string // Always "batch"
+type BatchCancelled string // Always "batch.cancelled"
+type BatchCompleted string // Always "batch.completed"
+type BatchExpired string // Always "batch.expired"
+type BatchFailed string // Always "batch.failed"
type ChatCompletion string // Always "chat.completion"
type ChatCompletionChunk string // Always "chat.completion.chunk"
type ChatCompletionDeleted string // Always "chat.completion.deleted"
+type CheckpointPermission string // Always "checkpoint.permission"
type Click string // Always "click"
type CodeInterpreter string // Always "code_interpreter"
type CodeInterpreterCall string // Always "code_interpreter_call"
type ComputerCallOutput string // Always "computer_call_output"
type ComputerScreenshot string // Always "computer_screenshot"
type ComputerUsePreview string // Always "computer_use_preview"
+type ContainerFileCitation string // Always "container_file_citation"
+type ContainerFile string // Always "container.file"
type Content string // Always "content"
type Developer string // Always "developer"
type DoubleClick string // Always "double_click"
type Drag string // Always "drag"
+type Duration string // Always "duration"
type Embedding string // Always "embedding"
type Error string // Always "error"
+type EvalRunCanceled string // Always "eval.run.canceled"
+type EvalRunFailed string // Always "eval.run.failed"
+type EvalRunSucceeded string // Always "eval.run.succeeded"
+type Exec string // Always "exec"
type File string // Always "file"
type FileCitation string // Always "file_citation"
type FilePath string // Always "file_path"
type FileSearch string // Always "file_search"
type FileSearchCall string // Always "file_search_call"
-type Files string // Always "files"
+type Find string // Always "find"
type FineTuningJob string // Always "fine_tuning.job"
+type FineTuningJobCancelled string // Always "fine_tuning.job.cancelled"
type FineTuningJobCheckpoint string // Always "fine_tuning.job.checkpoint"
type FineTuningJobEvent string // Always "fine_tuning.job.event"
+type FineTuningJobFailed string // Always "fine_tuning.job.failed"
+type FineTuningJobSucceeded string // Always "fine_tuning.job.succeeded"
type Function string // Always "function"
type FunctionCall string // Always "function_call"
type FunctionCallOutput string // Always "function_call_output"
type Image string // Always "image"
type ImageFile string // Always "image_file"
+type ImageGeneration string // Always "image_generation"
+type ImageGenerationCall string // Always "image_generation_call"
type ImageURL string // Always "image_url"
type InputAudio string // Always "input_audio"
type InputFile string // Always "input_file"
type InputImage string // Always "input_image"
type InputText string // Always "input_text"
-type ItemReference string // Always "item_reference"
type JSONObject string // Always "json_object"
type JSONSchema string // Always "json_schema"
type Keypress string // Always "keypress"
+type LabelModel string // Always "label_model"
type LastActiveAt string // Always "last_active_at"
type List string // Always "list"
+type LocalShell string // Always "local_shell"
+type LocalShellCall string // Always "local_shell_call"
+type LocalShellCallOutput string // Always "local_shell_call_output"
type Logs string // Always "logs"
+type Mcp string // Always "mcp"
+type McpApprovalRequest string // Always "mcp_approval_request"
+type McpApprovalResponse string // Always "mcp_approval_response"
+type McpCall string // Always "mcp_call"
+type McpListTools string // Always "mcp_list_tools"
type Message string // Always "message"
type MessageCreation string // Always "message_creation"
type Model string // Always "model"
type Move string // Always "move"
+type Multi string // Always "multi"
+type OpenPage string // Always "open_page"
type Other string // Always "other"
type OutputAudio string // Always "output_audio"
type OutputText string // Always "output_text"
+type Python string // Always "python"
type Reasoning string // Always "reasoning"
type Refusal string // Always "refusal"
type Response string // Always "response"
@@ -78,8 +106,9 @@ type ResponseAudioDelta string // Always "response.audio.de
type ResponseAudioDone string // Always "response.audio.done"
type ResponseAudioTranscriptDelta string // Always "response.audio.transcript.delta"
type ResponseAudioTranscriptDone string // Always "response.audio.transcript.done"
-type ResponseCodeInterpreterCallCodeDelta string // Always "response.code_interpreter_call.code.delta"
-type ResponseCodeInterpreterCallCodeDone string // Always "response.code_interpreter_call.code.done"
+type ResponseCancelled string // Always "response.cancelled"
+type ResponseCodeInterpreterCallCodeDelta string // Always "response.code_interpreter_call_code.delta"
+type ResponseCodeInterpreterCallCodeDone string // Always "response.code_interpreter_call_code.done"
type ResponseCodeInterpreterCallCompleted string // Always "response.code_interpreter_call.completed"
type ResponseCodeInterpreterCallInProgress string // Always "response.code_interpreter_call.in_progress"
type ResponseCodeInterpreterCallInterpreting string // Always "response.code_interpreter_call.interpreting"
@@ -93,26 +122,51 @@ type ResponseFileSearchCallInProgress string // Always "response.file_sea
type ResponseFileSearchCallSearching string // Always "response.file_search_call.searching"
type ResponseFunctionCallArgumentsDelta string // Always "response.function_call_arguments.delta"
type ResponseFunctionCallArgumentsDone string // Always "response.function_call_arguments.done"
+type ResponseImageGenerationCallCompleted string // Always "response.image_generation_call.completed"
+type ResponseImageGenerationCallGenerating string // Always "response.image_generation_call.generating"
+type ResponseImageGenerationCallInProgress string // Always "response.image_generation_call.in_progress"
+type ResponseImageGenerationCallPartialImage string // Always "response.image_generation_call.partial_image"
type ResponseInProgress string // Always "response.in_progress"
type ResponseIncomplete string // Always "response.incomplete"
+type ResponseMcpCallArgumentsDelta string // Always "response.mcp_call.arguments_delta"
+type ResponseMcpCallArgumentsDone string // Always "response.mcp_call.arguments_done"
+type ResponseMcpCallCompleted string // Always "response.mcp_call.completed"
+type ResponseMcpCallFailed string // Always "response.mcp_call.failed"
+type ResponseMcpCallInProgress string // Always "response.mcp_call.in_progress"
+type ResponseMcpListToolsCompleted string // Always "response.mcp_list_tools.completed"
+type ResponseMcpListToolsFailed string // Always "response.mcp_list_tools.failed"
+type ResponseMcpListToolsInProgress string // Always "response.mcp_list_tools.in_progress"
type ResponseOutputItemAdded string // Always "response.output_item.added"
type ResponseOutputItemDone string // Always "response.output_item.done"
-type ResponseOutputTextAnnotationAdded string // Always "response.output_text.annotation.added"
+type ResponseOutputTextAnnotationAdded string // Always "response.output_text_annotation.added"
type ResponseOutputTextDelta string // Always "response.output_text.delta"
type ResponseOutputTextDone string // Always "response.output_text.done"
+type ResponseQueued string // Always "response.queued"
+type ResponseReasoningSummaryPartAdded string // Always "response.reasoning_summary_part.added"
+type ResponseReasoningSummaryPartDone string // Always "response.reasoning_summary_part.done"
+type ResponseReasoningSummaryTextDelta string // Always "response.reasoning_summary_text.delta"
+type ResponseReasoningSummaryTextDone string // Always "response.reasoning_summary_text.done"
+type ResponseReasoningSummaryDelta string // Always "response.reasoning_summary.delta"
+type ResponseReasoningSummaryDone string // Always "response.reasoning_summary.done"
+type ResponseReasoningDelta string // Always "response.reasoning.delta"
+type ResponseReasoningDone string // Always "response.reasoning.done"
type ResponseRefusalDelta string // Always "response.refusal.delta"
type ResponseRefusalDone string // Always "response.refusal.done"
type ResponseWebSearchCallCompleted string // Always "response.web_search_call.completed"
type ResponseWebSearchCallInProgress string // Always "response.web_search_call.in_progress"
type ResponseWebSearchCallSearching string // Always "response.web_search_call.searching"
+type ScoreModel string // Always "score_model"
type Screenshot string // Always "screenshot"
type Scroll string // Always "scroll"
+type Search string // Always "search"
type Static string // Always "static"
+type StringCheck string // Always "string_check"
type SubmitToolOutputs string // Always "submit_tool_outputs"
type SummaryText string // Always "summary_text"
type System string // Always "system"
type Text string // Always "text"
type TextCompletion string // Always "text_completion"
+type TextSimilarity string // Always "text_similarity"
type Thread string // Always "thread"
type ThreadCreated string // Always "thread.created"
type ThreadDeleted string // Always "thread.deleted"
@@ -142,6 +196,7 @@ type ThreadRunStepDelta string // Always "thread.run.step.d
type ThreadRunStepExpired string // Always "thread.run.step.expired"
type ThreadRunStepFailed string // Always "thread.run.step.failed"
type ThreadRunStepInProgress string // Always "thread.run.step.in_progress"
+type Tokens string // Always "tokens"
type Tool string // Always "tool"
type ToolCalls string // Always "tool_calls"
type TranscriptTextDelta string // Always "transcript.text.delta"
@@ -162,77 +217,106 @@ type Wait string // Always "wait"
type Wandb string // Always "wandb"
type WebSearchCall string // Always "web_search_call"
-func (c Approximate) Default() Approximate { return "approximate" }
-func (c Assistant) Default() Assistant { return "assistant" }
-func (c AssistantDeleted) Default() AssistantDeleted { return "assistant.deleted" }
-func (c Auto) Default() Auto { return "auto" }
-func (c Batch) Default() Batch { return "batch" }
-func (c ChatCompletion) Default() ChatCompletion { return "chat.completion" }
-func (c ChatCompletionChunk) Default() ChatCompletionChunk { return "chat.completion.chunk" }
-func (c ChatCompletionDeleted) Default() ChatCompletionDeleted { return "chat.completion.deleted" }
-func (c Click) Default() Click { return "click" }
-func (c CodeInterpreter) Default() CodeInterpreter { return "code_interpreter" }
-func (c CodeInterpreterCall) Default() CodeInterpreterCall { return "code_interpreter_call" }
-func (c ComputerCallOutput) Default() ComputerCallOutput { return "computer_call_output" }
-func (c ComputerScreenshot) Default() ComputerScreenshot { return "computer_screenshot" }
-func (c ComputerUsePreview) Default() ComputerUsePreview { return "computer_use_preview" }
-func (c Content) Default() Content { return "content" }
-func (c Developer) Default() Developer { return "developer" }
-func (c DoubleClick) Default() DoubleClick { return "double_click" }
-func (c Drag) Default() Drag { return "drag" }
-func (c Embedding) Default() Embedding { return "embedding" }
-func (c Error) Default() Error { return "error" }
-func (c File) Default() File { return "file" }
-func (c FileCitation) Default() FileCitation { return "file_citation" }
-func (c FilePath) Default() FilePath { return "file_path" }
-func (c FileSearch) Default() FileSearch { return "file_search" }
-func (c FileSearchCall) Default() FileSearchCall { return "file_search_call" }
-func (c Files) Default() Files { return "files" }
-func (c FineTuningJob) Default() FineTuningJob { return "fine_tuning.job" }
+func (c Approximate) Default() Approximate { return "approximate" }
+func (c Assistant) Default() Assistant { return "assistant" }
+func (c AssistantDeleted) Default() AssistantDeleted { return "assistant.deleted" }
+func (c Auto) Default() Auto { return "auto" }
+func (c Batch) Default() Batch { return "batch" }
+func (c BatchCancelled) Default() BatchCancelled { return "batch.cancelled" }
+func (c BatchCompleted) Default() BatchCompleted { return "batch.completed" }
+func (c BatchExpired) Default() BatchExpired { return "batch.expired" }
+func (c BatchFailed) Default() BatchFailed { return "batch.failed" }
+func (c ChatCompletion) Default() ChatCompletion { return "chat.completion" }
+func (c ChatCompletionChunk) Default() ChatCompletionChunk { return "chat.completion.chunk" }
+func (c ChatCompletionDeleted) Default() ChatCompletionDeleted { return "chat.completion.deleted" }
+func (c CheckpointPermission) Default() CheckpointPermission { return "checkpoint.permission" }
+func (c Click) Default() Click { return "click" }
+func (c CodeInterpreter) Default() CodeInterpreter { return "code_interpreter" }
+func (c CodeInterpreterCall) Default() CodeInterpreterCall { return "code_interpreter_call" }
+func (c ComputerCallOutput) Default() ComputerCallOutput { return "computer_call_output" }
+func (c ComputerScreenshot) Default() ComputerScreenshot { return "computer_screenshot" }
+func (c ComputerUsePreview) Default() ComputerUsePreview { return "computer_use_preview" }
+func (c ContainerFileCitation) Default() ContainerFileCitation { return "container_file_citation" }
+func (c ContainerFile) Default() ContainerFile { return "container.file" }
+func (c Content) Default() Content { return "content" }
+func (c Developer) Default() Developer { return "developer" }
+func (c DoubleClick) Default() DoubleClick { return "double_click" }
+func (c Drag) Default() Drag { return "drag" }
+func (c Duration) Default() Duration { return "duration" }
+func (c Embedding) Default() Embedding { return "embedding" }
+func (c Error) Default() Error { return "error" }
+func (c EvalRunCanceled) Default() EvalRunCanceled { return "eval.run.canceled" }
+func (c EvalRunFailed) Default() EvalRunFailed { return "eval.run.failed" }
+func (c EvalRunSucceeded) Default() EvalRunSucceeded { return "eval.run.succeeded" }
+func (c Exec) Default() Exec { return "exec" }
+func (c File) Default() File { return "file" }
+func (c FileCitation) Default() FileCitation { return "file_citation" }
+func (c FilePath) Default() FilePath { return "file_path" }
+func (c FileSearch) Default() FileSearch { return "file_search" }
+func (c FileSearchCall) Default() FileSearchCall { return "file_search_call" }
+func (c Find) Default() Find { return "find" }
+func (c FineTuningJob) Default() FineTuningJob { return "fine_tuning.job" }
+func (c FineTuningJobCancelled) Default() FineTuningJobCancelled { return "fine_tuning.job.cancelled" }
func (c FineTuningJobCheckpoint) Default() FineTuningJobCheckpoint {
return "fine_tuning.job.checkpoint"
}
-func (c FineTuningJobEvent) Default() FineTuningJobEvent { return "fine_tuning.job.event" }
-func (c Function) Default() Function { return "function" }
-func (c FunctionCall) Default() FunctionCall { return "function_call" }
-func (c FunctionCallOutput) Default() FunctionCallOutput { return "function_call_output" }
-func (c Image) Default() Image { return "image" }
-func (c ImageFile) Default() ImageFile { return "image_file" }
-func (c ImageURL) Default() ImageURL { return "image_url" }
-func (c InputAudio) Default() InputAudio { return "input_audio" }
-func (c InputFile) Default() InputFile { return "input_file" }
-func (c InputImage) Default() InputImage { return "input_image" }
-func (c InputText) Default() InputText { return "input_text" }
-func (c ItemReference) Default() ItemReference { return "item_reference" }
-func (c JSONObject) Default() JSONObject { return "json_object" }
-func (c JSONSchema) Default() JSONSchema { return "json_schema" }
-func (c Keypress) Default() Keypress { return "keypress" }
-func (c LastActiveAt) Default() LastActiveAt { return "last_active_at" }
-func (c List) Default() List { return "list" }
-func (c Logs) Default() Logs { return "logs" }
-func (c Message) Default() Message { return "message" }
-func (c MessageCreation) Default() MessageCreation { return "message_creation" }
-func (c Model) Default() Model { return "model" }
-func (c Move) Default() Move { return "move" }
-func (c Other) Default() Other { return "other" }
-func (c OutputAudio) Default() OutputAudio { return "output_audio" }
-func (c OutputText) Default() OutputText { return "output_text" }
-func (c Reasoning) Default() Reasoning { return "reasoning" }
-func (c Refusal) Default() Refusal { return "refusal" }
-func (c Response) Default() Response { return "response" }
-func (c ResponseAudioDelta) Default() ResponseAudioDelta { return "response.audio.delta" }
-func (c ResponseAudioDone) Default() ResponseAudioDone { return "response.audio.done" }
+func (c FineTuningJobEvent) Default() FineTuningJobEvent { return "fine_tuning.job.event" }
+func (c FineTuningJobFailed) Default() FineTuningJobFailed { return "fine_tuning.job.failed" }
+func (c FineTuningJobSucceeded) Default() FineTuningJobSucceeded { return "fine_tuning.job.succeeded" }
+func (c Function) Default() Function { return "function" }
+func (c FunctionCall) Default() FunctionCall { return "function_call" }
+func (c FunctionCallOutput) Default() FunctionCallOutput { return "function_call_output" }
+func (c Image) Default() Image { return "image" }
+func (c ImageFile) Default() ImageFile { return "image_file" }
+func (c ImageGeneration) Default() ImageGeneration { return "image_generation" }
+func (c ImageGenerationCall) Default() ImageGenerationCall { return "image_generation_call" }
+func (c ImageURL) Default() ImageURL { return "image_url" }
+func (c InputAudio) Default() InputAudio { return "input_audio" }
+func (c InputFile) Default() InputFile { return "input_file" }
+func (c InputImage) Default() InputImage { return "input_image" }
+func (c InputText) Default() InputText { return "input_text" }
+func (c JSONObject) Default() JSONObject { return "json_object" }
+func (c JSONSchema) Default() JSONSchema { return "json_schema" }
+func (c Keypress) Default() Keypress { return "keypress" }
+func (c LabelModel) Default() LabelModel { return "label_model" }
+func (c LastActiveAt) Default() LastActiveAt { return "last_active_at" }
+func (c List) Default() List { return "list" }
+func (c LocalShell) Default() LocalShell { return "local_shell" }
+func (c LocalShellCall) Default() LocalShellCall { return "local_shell_call" }
+func (c LocalShellCallOutput) Default() LocalShellCallOutput { return "local_shell_call_output" }
+func (c Logs) Default() Logs { return "logs" }
+func (c Mcp) Default() Mcp { return "mcp" }
+func (c McpApprovalRequest) Default() McpApprovalRequest { return "mcp_approval_request" }
+func (c McpApprovalResponse) Default() McpApprovalResponse { return "mcp_approval_response" }
+func (c McpCall) Default() McpCall { return "mcp_call" }
+func (c McpListTools) Default() McpListTools { return "mcp_list_tools" }
+func (c Message) Default() Message { return "message" }
+func (c MessageCreation) Default() MessageCreation { return "message_creation" }
+func (c Model) Default() Model { return "model" }
+func (c Move) Default() Move { return "move" }
+func (c Multi) Default() Multi { return "multi" }
+func (c OpenPage) Default() OpenPage { return "open_page" }
+func (c Other) Default() Other { return "other" }
+func (c OutputAudio) Default() OutputAudio { return "output_audio" }
+func (c OutputText) Default() OutputText { return "output_text" }
+func (c Python) Default() Python { return "python" }
+func (c Reasoning) Default() Reasoning { return "reasoning" }
+func (c Refusal) Default() Refusal { return "refusal" }
+func (c Response) Default() Response { return "response" }
+func (c ResponseAudioDelta) Default() ResponseAudioDelta { return "response.audio.delta" }
+func (c ResponseAudioDone) Default() ResponseAudioDone { return "response.audio.done" }
func (c ResponseAudioTranscriptDelta) Default() ResponseAudioTranscriptDelta {
return "response.audio.transcript.delta"
}
func (c ResponseAudioTranscriptDone) Default() ResponseAudioTranscriptDone {
return "response.audio.transcript.done"
}
+func (c ResponseCancelled) Default() ResponseCancelled { return "response.cancelled" }
func (c ResponseCodeInterpreterCallCodeDelta) Default() ResponseCodeInterpreterCallCodeDelta {
- return "response.code_interpreter_call.code.delta"
+ return "response.code_interpreter_call_code.delta"
}
func (c ResponseCodeInterpreterCallCodeDone) Default() ResponseCodeInterpreterCallCodeDone {
- return "response.code_interpreter_call.code.done"
+ return "response.code_interpreter_call_code.done"
}
func (c ResponseCodeInterpreterCallCompleted) Default() ResponseCodeInterpreterCallCompleted {
return "response.code_interpreter_call.completed"
@@ -267,19 +351,74 @@ func (c ResponseFunctionCallArgumentsDelta) Default() ResponseFunctionCallArgume
func (c ResponseFunctionCallArgumentsDone) Default() ResponseFunctionCallArgumentsDone {
return "response.function_call_arguments.done"
}
+func (c ResponseImageGenerationCallCompleted) Default() ResponseImageGenerationCallCompleted {
+ return "response.image_generation_call.completed"
+}
+func (c ResponseImageGenerationCallGenerating) Default() ResponseImageGenerationCallGenerating {
+ return "response.image_generation_call.generating"
+}
+func (c ResponseImageGenerationCallInProgress) Default() ResponseImageGenerationCallInProgress {
+ return "response.image_generation_call.in_progress"
+}
+func (c ResponseImageGenerationCallPartialImage) Default() ResponseImageGenerationCallPartialImage {
+ return "response.image_generation_call.partial_image"
+}
func (c ResponseInProgress) Default() ResponseInProgress { return "response.in_progress" }
func (c ResponseIncomplete) Default() ResponseIncomplete { return "response.incomplete" }
+func (c ResponseMcpCallArgumentsDelta) Default() ResponseMcpCallArgumentsDelta {
+ return "response.mcp_call.arguments_delta"
+}
+func (c ResponseMcpCallArgumentsDone) Default() ResponseMcpCallArgumentsDone {
+ return "response.mcp_call.arguments_done"
+}
+func (c ResponseMcpCallCompleted) Default() ResponseMcpCallCompleted {
+ return "response.mcp_call.completed"
+}
+func (c ResponseMcpCallFailed) Default() ResponseMcpCallFailed { return "response.mcp_call.failed" }
+func (c ResponseMcpCallInProgress) Default() ResponseMcpCallInProgress {
+ return "response.mcp_call.in_progress"
+}
+func (c ResponseMcpListToolsCompleted) Default() ResponseMcpListToolsCompleted {
+ return "response.mcp_list_tools.completed"
+}
+func (c ResponseMcpListToolsFailed) Default() ResponseMcpListToolsFailed {
+ return "response.mcp_list_tools.failed"
+}
+func (c ResponseMcpListToolsInProgress) Default() ResponseMcpListToolsInProgress {
+ return "response.mcp_list_tools.in_progress"
+}
func (c ResponseOutputItemAdded) Default() ResponseOutputItemAdded {
return "response.output_item.added"
}
func (c ResponseOutputItemDone) Default() ResponseOutputItemDone { return "response.output_item.done" }
func (c ResponseOutputTextAnnotationAdded) Default() ResponseOutputTextAnnotationAdded {
- return "response.output_text.annotation.added"
+ return "response.output_text_annotation.added"
}
func (c ResponseOutputTextDelta) Default() ResponseOutputTextDelta {
return "response.output_text.delta"
}
func (c ResponseOutputTextDone) Default() ResponseOutputTextDone { return "response.output_text.done" }
+func (c ResponseQueued) Default() ResponseQueued { return "response.queued" }
+func (c ResponseReasoningSummaryPartAdded) Default() ResponseReasoningSummaryPartAdded {
+ return "response.reasoning_summary_part.added"
+}
+func (c ResponseReasoningSummaryPartDone) Default() ResponseReasoningSummaryPartDone {
+ return "response.reasoning_summary_part.done"
+}
+func (c ResponseReasoningSummaryTextDelta) Default() ResponseReasoningSummaryTextDelta {
+ return "response.reasoning_summary_text.delta"
+}
+func (c ResponseReasoningSummaryTextDone) Default() ResponseReasoningSummaryTextDone {
+ return "response.reasoning_summary_text.done"
+}
+func (c ResponseReasoningSummaryDelta) Default() ResponseReasoningSummaryDelta {
+ return "response.reasoning_summary.delta"
+}
+func (c ResponseReasoningSummaryDone) Default() ResponseReasoningSummaryDone {
+ return "response.reasoning_summary.done"
+}
+func (c ResponseReasoningDelta) Default() ResponseReasoningDelta { return "response.reasoning.delta" }
+func (c ResponseReasoningDone) Default() ResponseReasoningDone { return "response.reasoning.done" }
func (c ResponseRefusalDelta) Default() ResponseRefusalDelta { return "response.refusal.delta" }
func (c ResponseRefusalDone) Default() ResponseRefusalDone { return "response.refusal.done" }
func (c ResponseWebSearchCallCompleted) Default() ResponseWebSearchCallCompleted {
@@ -291,14 +430,18 @@ func (c ResponseWebSearchCallInProgress) Default() ResponseWebSearchCallInProgre
func (c ResponseWebSearchCallSearching) Default() ResponseWebSearchCallSearching {
return "response.web_search_call.searching"
}
+func (c ScoreModel) Default() ScoreModel { return "score_model" }
func (c Screenshot) Default() Screenshot { return "screenshot" }
func (c Scroll) Default() Scroll { return "scroll" }
+func (c Search) Default() Search { return "search" }
func (c Static) Default() Static { return "static" }
+func (c StringCheck) Default() StringCheck { return "string_check" }
func (c SubmitToolOutputs) Default() SubmitToolOutputs { return "submit_tool_outputs" }
func (c SummaryText) Default() SummaryText { return "summary_text" }
func (c System) Default() System { return "system" }
func (c Text) Default() Text { return "text" }
func (c TextCompletion) Default() TextCompletion { return "text_completion" }
+func (c TextSimilarity) Default() TextSimilarity { return "text_similarity" }
func (c Thread) Default() Thread { return "thread" }
func (c ThreadCreated) Default() ThreadCreated { return "thread.created" }
func (c ThreadDeleted) Default() ThreadDeleted { return "thread.deleted" }
@@ -336,6 +479,7 @@ func (c ThreadRunStepFailed) Default() ThreadRunStepFailed { return "threa
func (c ThreadRunStepInProgress) Default() ThreadRunStepInProgress {
return "thread.run.step.in_progress"
}
+func (c Tokens) Default() Tokens { return "tokens" }
func (c Tool) Default() Tool { return "tool" }
func (c ToolCalls) Default() ToolCalls { return "tool_calls" }
func (c TranscriptTextDelta) Default() TranscriptTextDelta { return "transcript.text.delta" }
@@ -365,54 +509,82 @@ func (c Assistant) MarshalJSON() ([]byte, error) { r
func (c AssistantDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Auto) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Batch) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c BatchCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c BatchCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c BatchExpired) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c BatchFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ChatCompletion) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ChatCompletionChunk) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ChatCompletionDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c CheckpointPermission) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Click) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c CodeInterpreter) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c CodeInterpreterCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ComputerCallOutput) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ComputerScreenshot) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ComputerUsePreview) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ContainerFileCitation) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ContainerFile) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Content) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Developer) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c DoubleClick) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Drag) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Duration) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Embedding) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Error) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c EvalRunCanceled) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c EvalRunFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c EvalRunSucceeded) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Exec) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c File) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FileCitation) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FilePath) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FileSearch) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FileSearchCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Files) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Find) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FineTuningJob) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c FineTuningJobCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FineTuningJobCheckpoint) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FineTuningJobEvent) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c FineTuningJobFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c FineTuningJobSucceeded) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Function) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FunctionCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c FunctionCallOutput) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Image) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ImageFile) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ImageGeneration) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ImageGenerationCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ImageURL) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c InputAudio) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c InputFile) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c InputImage) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c InputText) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ItemReference) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c JSONObject) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c JSONSchema) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Keypress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c LabelModel) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c LastActiveAt) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c List) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c LocalShell) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c LocalShellCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c LocalShellCallOutput) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Logs) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Mcp) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c McpApprovalRequest) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c McpApprovalResponse) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c McpCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c McpListTools) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Message) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c MessageCreation) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Model) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Move) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Multi) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c OpenPage) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Other) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c OutputAudio) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c OutputText) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Python) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Reasoning) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Refusal) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c Response) MarshalJSON() ([]byte, error) { return marshalString(c) }
@@ -420,6 +592,7 @@ func (c ResponseAudioDelta) MarshalJSON() ([]byte, error) { r
func (c ResponseAudioDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ResponseAudioTranscriptDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ResponseAudioTranscriptDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ResponseCodeInterpreterCallCodeDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ResponseCodeInterpreterCallCodeDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
func (c ResponseCodeInterpreterCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
@@ -427,84 +600,112 @@ func (c ResponseCodeInterpreterCallInProgress) MarshalJSON() ([]byte, error) { r
func (c ResponseCodeInterpreterCallInterpreting) MarshalJSON() ([]byte, error) {
return marshalString(c)
}
-func (c ResponseCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseContentPartAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseContentPartDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseFileSearchCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseFileSearchCallInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseFileSearchCallSearching) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseFunctionCallArgumentsDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseFunctionCallArgumentsDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseIncomplete) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseOutputItemAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseOutputItemDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseOutputTextAnnotationAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseOutputTextDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseOutputTextDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseRefusalDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseRefusalDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseWebSearchCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseWebSearchCallInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ResponseWebSearchCallSearching) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Screenshot) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Scroll) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Static) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c SubmitToolOutputs) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c SummaryText) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c System) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Text) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c TextCompletion) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Thread) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessage) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessageCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessageCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessageDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessageDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessageInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadMessageIncomplete) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRun) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunCancelling) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunExpired) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunIncomplete) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunQueued) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunRequiresAction) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStep) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepExpired) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ThreadRunStepInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Tool) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c ToolCalls) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c TranscriptTextDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c TranscriptTextDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Type) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Upload) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c UploadPart) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c URLCitation) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c User) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStore) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStoreDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStoreFile) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStoreFileContentPage) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStoreFileDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStoreFilesBatch) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c VectorStoreSearchResultsPage) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Wait) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c Wandb) MarshalJSON() ([]byte, error) { return marshalString(c) }
-func (c WebSearchCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseContentPartAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseContentPartDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseFileSearchCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseFileSearchCallInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseFileSearchCallSearching) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseFunctionCallArgumentsDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseFunctionCallArgumentsDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseImageGenerationCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseImageGenerationCallGenerating) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseImageGenerationCallInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseImageGenerationCallPartialImage) MarshalJSON() ([]byte, error) {
+ return marshalString(c)
+}
+func (c ResponseInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseIncomplete) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpCallArgumentsDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpCallArgumentsDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpCallFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpCallInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpListToolsCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpListToolsFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseMcpListToolsInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseOutputItemAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseOutputItemDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseOutputTextAnnotationAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseOutputTextDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseOutputTextDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseQueued) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningSummaryPartAdded) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningSummaryPartDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningSummaryTextDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningSummaryTextDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningSummaryDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningSummaryDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseReasoningDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseRefusalDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseRefusalDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseWebSearchCallCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseWebSearchCallInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ResponseWebSearchCallSearching) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ScoreModel) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Screenshot) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Scroll) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Search) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Static) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c StringCheck) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c SubmitToolOutputs) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c SummaryText) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c System) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Text) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c TextCompletion) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c TextSimilarity) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Thread) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessage) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessageCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessageCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessageDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessageDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessageInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadMessageIncomplete) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRun) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunCancelling) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunExpired) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunIncomplete) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunQueued) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunRequiresAction) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStep) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepCancelled) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepCompleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepCreated) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepExpired) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepFailed) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ThreadRunStepInProgress) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Tokens) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Tool) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c ToolCalls) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c TranscriptTextDelta) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c TranscriptTextDone) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Type) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Upload) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c UploadPart) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c URLCitation) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c User) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStore) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStoreDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStoreFile) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStoreFileContentPage) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStoreFileDeleted) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStoreFilesBatch) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c VectorStoreSearchResultsPage) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Wait) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c Wandb) MarshalJSON() ([]byte, error) { return marshalString(c) }
+func (c WebSearchCall) MarshalJSON() ([]byte, error) { return marshalString(c) }
type constant[T any] interface {
Constant[T]
@@ -7,19 +7,30 @@ import (
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
// aliased to make [param.APIUnion] private when embedding
type paramUnion = param.APIUnion
-
-// aliased to make [param.APIObject] private when embedding
type paramObj = param.APIObject
type ChatModel = string
+type ResponsesModel = string
+
+// aliased to make [param.APIObject] private when embedding
const (
+ ChatModelGPT4_1 ChatModel = "gpt-4.1"
+ ChatModelGPT4_1Mini ChatModel = "gpt-4.1-mini"
+ ChatModelGPT4_1Nano ChatModel = "gpt-4.1-nano"
+ ChatModelGPT4_1_2025_04_14 ChatModel = "gpt-4.1-2025-04-14"
+ ChatModelGPT4_1Mini2025_04_14 ChatModel = "gpt-4.1-mini-2025-04-14"
+ ChatModelGPT4_1Nano2025_04_14 ChatModel = "gpt-4.1-nano-2025-04-14"
+ ChatModelO4Mini ChatModel = "o4-mini"
+ ChatModelO4Mini2025_04_16 ChatModel = "o4-mini-2025-04-16"
+ ChatModelO3 ChatModel = "o3"
+ ChatModelO3_2025_04_16 ChatModel = "o3-2025-04-16"
ChatModelO3Mini ChatModel = "o3-mini"
ChatModelO3Mini2025_01_31 ChatModel = "o3-mini-2025-01-31"
ChatModelO1 ChatModel = "o1"
@@ -35,6 +46,7 @@ const (
ChatModelGPT4oAudioPreview ChatModel = "gpt-4o-audio-preview"
ChatModelGPT4oAudioPreview2024_10_01 ChatModel = "gpt-4o-audio-preview-2024-10-01"
ChatModelGPT4oAudioPreview2024_12_17 ChatModel = "gpt-4o-audio-preview-2024-12-17"
+ ChatModelGPT4oAudioPreview2025_06_03 ChatModel = "gpt-4o-audio-preview-2025-06-03"
ChatModelGPT4oMiniAudioPreview ChatModel = "gpt-4o-mini-audio-preview"
ChatModelGPT4oMiniAudioPreview2024_12_17 ChatModel = "gpt-4o-mini-audio-preview-2024-12-17"
ChatModelGPT4oSearchPreview ChatModel = "gpt-4o-search-preview"
@@ -42,6 +54,7 @@ const (
ChatModelGPT4oSearchPreview2025_03_11 ChatModel = "gpt-4o-search-preview-2025-03-11"
ChatModelGPT4oMiniSearchPreview2025_03_11 ChatModel = "gpt-4o-mini-search-preview-2025-03-11"
ChatModelChatgpt4oLatest ChatModel = "chatgpt-4o-latest"
+ ChatModelCodexMiniLatest ChatModel = "codex-mini-latest"
ChatModelGPT4oMini ChatModel = "gpt-4o-mini"
ChatModelGPT4oMini2024_07_18 ChatModel = "gpt-4o-mini-2024-07-18"
ChatModelGPT4Turbo ChatModel = "gpt-4-turbo"
@@ -84,13 +97,12 @@ type ComparisonFilter struct {
// The value to compare against the attribute key; supports string, number, or
// boolean types.
Value ComparisonFilterValueUnion `json:"value,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Key resp.Field
- Type resp.Field
- Value resp.Field
- ExtraFields map[string]resp.Field
+ Key respjson.Field
+ Type respjson.Field
+ Value respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -105,9 +117,9 @@ func (r *ComparisonFilter) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ComparisonFilterParam.IsOverridden()
+// ComparisonFilterParam.Overrides()
func (r ComparisonFilter) ToParam() ComparisonFilterParam {
- return param.OverrideObj[ComparisonFilterParam](r.RawJSON())
+ return param.Override[ComparisonFilterParam](json.RawMessage(r.RawJSON()))
}
// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
@@ -144,9 +156,9 @@ type ComparisonFilterValueUnion struct {
// This field will be present if the value is a [bool] instead of an object.
OfBool bool `json:",inline"`
JSON struct {
- OfString resp.Field
- OfFloat resp.Field
- OfBool resp.Field
+ OfString respjson.Field
+ OfFloat respjson.Field
+ OfBool respjson.Field
raw string
} `json:"-"`
}
@@ -197,13 +209,13 @@ type ComparisonFilterParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ComparisonFilterParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ComparisonFilterParam) MarshalJSON() (data []byte, err error) {
type shadow ComparisonFilterParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ComparisonFilterParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -215,11 +227,11 @@ type ComparisonFilterValueUnionParam struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u ComparisonFilterValueUnionParam) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u ComparisonFilterValueUnionParam) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[ComparisonFilterValueUnionParam](u.OfString, u.OfFloat, u.OfBool)
+ return param.MarshalUnion(u, u.OfString, u.OfFloat, u.OfBool)
+}
+func (u *ComparisonFilterValueUnionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *ComparisonFilterValueUnionParam) asAny() any {
@@ -242,12 +254,11 @@ type CompoundFilter struct {
//
// Any of "and", "or".
Type CompoundFilterType `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Filters resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Filters respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -262,9 +273,9 @@ func (r *CompoundFilter) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// CompoundFilterParam.IsOverridden()
+// CompoundFilterParam.Overrides()
func (r CompoundFilter) ToParam() CompoundFilterParam {
- return param.OverrideObj[CompoundFilterParam](r.RawJSON())
+ return param.Override[CompoundFilterParam](json.RawMessage(r.RawJSON()))
}
// Type of operation: `and` or `or`.
@@ -289,27 +300,26 @@ type CompoundFilterParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f CompoundFilterParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r CompoundFilterParam) MarshalJSON() (data []byte, err error) {
type shadow CompoundFilterParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *CompoundFilterParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type ErrorObject struct {
Code string `json:"code,required"`
Message string `json:"message,required"`
Param string `json:"param,required"`
Type string `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Message resp.Field
- Param resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Message respjson.Field
+ Param respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -341,14 +351,13 @@ type FunctionDefinition struct {
// more about Structured Outputs in the
// [function calling guide](docs/guides/function-calling).
Strict bool `json:"strict,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Name resp.Field
- Description resp.Field
- Parameters resp.Field
- Strict resp.Field
- ExtraFields map[string]resp.Field
+ Name respjson.Field
+ Description respjson.Field
+ Parameters respjson.Field
+ Strict respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -363,9 +372,9 @@ func (r *FunctionDefinition) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// FunctionDefinitionParam.IsOverridden()
+// FunctionDefinitionParam.Overrides()
func (r FunctionDefinition) ToParam() FunctionDefinitionParam {
- return param.OverrideObj[FunctionDefinitionParam](r.RawJSON())
+ return param.Override[FunctionDefinitionParam](json.RawMessage(r.RawJSON()))
}
// The property Name is required.
@@ -393,20 +402,18 @@ type FunctionDefinitionParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f FunctionDefinitionParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r FunctionDefinitionParam) MarshalJSON() (data []byte, err error) {
type shadow FunctionDefinitionParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *FunctionDefinitionParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
-type FunctionParameters map[string]interface{}
+type FunctionParameters map[string]any
type Metadata map[string]string
-type MetadataParam map[string]string
-
// **o-series models only**
//
// Configuration options for
@@ -421,20 +428,28 @@ type Reasoning struct {
//
// Any of "low", "medium", "high".
Effort ReasoningEffort `json:"effort,nullable"`
- // **computer_use_preview only**
+ // **Deprecated:** use `summary` instead.
//
// A summary of the reasoning performed by the model. This can be useful for
- // debugging and understanding the model's reasoning process. One of `concise` or
- // `detailed`.
+ // debugging and understanding the model's reasoning process. One of `auto`,
+ // `concise`, or `detailed`.
+ //
+ // Any of "auto", "concise", "detailed".
//
- // Any of "concise", "detailed".
+ // Deprecated: deprecated
GenerateSummary ReasoningGenerateSummary `json:"generate_summary,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // A summary of the reasoning performed by the model. This can be useful for
+ // debugging and understanding the model's reasoning process. One of `auto`,
+ // `concise`, or `detailed`.
+ //
+ // Any of "auto", "concise", "detailed".
+ Summary ReasoningSummary `json:"summary,nullable"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Effort resp.Field
- GenerateSummary resp.Field
- ExtraFields map[string]resp.Field
+ Effort respjson.Field
+ GenerateSummary respjson.Field
+ Summary respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -449,23 +464,35 @@ func (r *Reasoning) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ReasoningParam.IsOverridden()
+// ReasoningParam.Overrides()
func (r Reasoning) ToParam() ReasoningParam {
- return param.OverrideObj[ReasoningParam](r.RawJSON())
+ return param.Override[ReasoningParam](json.RawMessage(r.RawJSON()))
}
-// **computer_use_preview only**
+// **Deprecated:** use `summary` instead.
//
// A summary of the reasoning performed by the model. This can be useful for
-// debugging and understanding the model's reasoning process. One of `concise` or
-// `detailed`.
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
type ReasoningGenerateSummary string
const (
+ ReasoningGenerateSummaryAuto ReasoningGenerateSummary = "auto"
ReasoningGenerateSummaryConcise ReasoningGenerateSummary = "concise"
ReasoningGenerateSummaryDetailed ReasoningGenerateSummary = "detailed"
)
+// A summary of the reasoning performed by the model. This can be useful for
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
+type ReasoningSummary string
+
+const (
+ ReasoningSummaryAuto ReasoningSummary = "auto"
+ ReasoningSummaryConcise ReasoningSummary = "concise"
+ ReasoningSummaryDetailed ReasoningSummary = "detailed"
+)
+
// **o-series models only**
//
// Configuration options for
@@ -480,24 +507,32 @@ type ReasoningParam struct {
//
// Any of "low", "medium", "high".
Effort ReasoningEffort `json:"effort,omitzero"`
- // **computer_use_preview only**
+ // **Deprecated:** use `summary` instead.
//
// A summary of the reasoning performed by the model. This can be useful for
- // debugging and understanding the model's reasoning process. One of `concise` or
- // `detailed`.
+ // debugging and understanding the model's reasoning process. One of `auto`,
+ // `concise`, or `detailed`.
//
- // Any of "concise", "detailed".
+ // Any of "auto", "concise", "detailed".
+ //
+ // Deprecated: deprecated
GenerateSummary ReasoningGenerateSummary `json:"generate_summary,omitzero"`
+ // A summary of the reasoning performed by the model. This can be useful for
+ // debugging and understanding the model's reasoning process. One of `auto`,
+ // `concise`, or `detailed`.
+ //
+ // Any of "auto", "concise", "detailed".
+ Summary ReasoningSummary `json:"summary,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ReasoningParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ReasoningParam) MarshalJSON() (data []byte, err error) {
type shadow ReasoningParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ReasoningParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// **o-series models only**
//
@@ -519,11 +554,10 @@ const (
type ResponseFormatJSONObject struct {
// The type of response format being defined. Always `json_object`.
Type constant.JSONObject `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -534,36 +568,43 @@ func (r *ResponseFormatJSONObject) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+func (ResponseFormatJSONObject) ImplResponseFormatTextConfigUnion() {}
+
// ToParam converts this ResponseFormatJSONObject to a
// ResponseFormatJSONObjectParam.
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ResponseFormatJSONObjectParam.IsOverridden()
+// ResponseFormatJSONObjectParam.Overrides()
func (r ResponseFormatJSONObject) ToParam() ResponseFormatJSONObjectParam {
- return param.OverrideObj[ResponseFormatJSONObjectParam](r.RawJSON())
+ return param.Override[ResponseFormatJSONObjectParam](json.RawMessage(r.RawJSON()))
+}
+
+func NewResponseFormatJSONObjectParam() ResponseFormatJSONObjectParam {
+ return ResponseFormatJSONObjectParam{
+ Type: "json_object",
+ }
}
// JSON object response format. An older method of generating JSON responses. Using
// `json_schema` is recommended for models that support it. Note that the model
// will not generate JSON without a system or user message instructing it to do so.
//
-// The property Type is required.
+// This struct has a constant value, construct it with
+// [NewResponseFormatJSONObjectParam].
type ResponseFormatJSONObjectParam struct {
// The type of response format being defined. Always `json_object`.
- //
- // This field can be elided, and will marshal its zero value as "json_object".
Type constant.JSONObject `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ResponseFormatJSONObjectParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ResponseFormatJSONObjectParam) MarshalJSON() (data []byte, err error) {
type shadow ResponseFormatJSONObjectParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ResponseFormatJSONObjectParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// JSON Schema response format. Used to generate structured JSON responses. Learn
// more about
@@ -573,12 +614,11 @@ type ResponseFormatJSONSchema struct {
JSONSchema ResponseFormatJSONSchemaJSONSchema `json:"json_schema,required"`
// The type of response format being defined. Always `json_schema`.
Type constant.JSONSchema `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- JSONSchema resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ JSONSchema respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -594,9 +634,9 @@ func (r *ResponseFormatJSONSchema) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ResponseFormatJSONSchemaParam.IsOverridden()
+// ResponseFormatJSONSchemaParam.Overrides()
func (r ResponseFormatJSONSchema) ToParam() ResponseFormatJSONSchemaParam {
- return param.OverrideObj[ResponseFormatJSONSchemaParam](r.RawJSON())
+ return param.Override[ResponseFormatJSONSchemaParam](json.RawMessage(r.RawJSON()))
}
// Structured Outputs configuration options, including a JSON Schema.
@@ -609,21 +649,20 @@ type ResponseFormatJSONSchemaJSONSchema struct {
Description string `json:"description"`
// The schema for the response format, described as a JSON Schema object. Learn how
// to build JSON schemas [here](https://json-schema.org/).
- Schema map[string]interface{} `json:"schema"`
+ Schema map[string]any `json:"schema"`
// Whether to enable strict schema adherence when generating the output. If set to
// true, the model will always follow the exact schema defined in the `schema`
// field. Only a subset of JSON Schema is supported when `strict` is `true`. To
// learn more, read the
// [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
Strict bool `json:"strict,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Name resp.Field
- Description resp.Field
- Schema resp.Field
- Strict resp.Field
- ExtraFields map[string]resp.Field
+ Name respjson.Field
+ Description respjson.Field
+ Schema respjson.Field
+ Strict respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -649,13 +688,13 @@ type ResponseFormatJSONSchemaParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ResponseFormatJSONSchemaParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ResponseFormatJSONSchemaParam) MarshalJSON() (data []byte, err error) {
type shadow ResponseFormatJSONSchemaParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ResponseFormatJSONSchemaParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Structured Outputs configuration options, including a JSON Schema.
//
@@ -675,29 +714,26 @@ type ResponseFormatJSONSchemaJSONSchemaParam struct {
Description param.Opt[string] `json:"description,omitzero"`
// The schema for the response format, described as a JSON Schema object. Learn how
// to build JSON schemas [here](https://json-schema.org/).
- Schema interface{} `json:"schema,omitzero"`
+ Schema any `json:"schema,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ResponseFormatJSONSchemaJSONSchemaParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r ResponseFormatJSONSchemaJSONSchemaParam) MarshalJSON() (data []byte, err error) {
type shadow ResponseFormatJSONSchemaJSONSchemaParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ResponseFormatJSONSchemaJSONSchemaParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Default response format. Used to generate text responses.
type ResponseFormatText struct {
// The type of response format being defined. Always `text`.
Type constant.Text `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -708,40 +744,52 @@ func (r *ResponseFormatText) UnmarshalJSON(data []byte) error {
return apijson.UnmarshalRoot(data, r)
}
+func (ResponseFormatText) ImplResponseFormatTextConfigUnion() {}
+
// ToParam converts this ResponseFormatText to a ResponseFormatTextParam.
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// ResponseFormatTextParam.IsOverridden()
+// ResponseFormatTextParam.Overrides()
func (r ResponseFormatText) ToParam() ResponseFormatTextParam {
- return param.OverrideObj[ResponseFormatTextParam](r.RawJSON())
+ return param.Override[ResponseFormatTextParam](json.RawMessage(r.RawJSON()))
+}
+
+func NewResponseFormatTextParam() ResponseFormatTextParam {
+ return ResponseFormatTextParam{
+ Type: "text",
+ }
}
// Default response format. Used to generate text responses.
//
-// The property Type is required.
+// This struct has a constant value, construct it with
+// [NewResponseFormatTextParam].
type ResponseFormatTextParam struct {
// The type of response format being defined. Always `text`.
- //
- // This field can be elided, and will marshal its zero value as "text".
Type constant.Text `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f ResponseFormatTextParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r ResponseFormatTextParam) MarshalJSON() (data []byte, err error) {
type shadow ResponseFormatTextParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *ResponseFormatTextParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// ResponsesModel also accepts any [string] or [ChatModel]
-type ResponsesModel = string
const (
ResponsesModelO1Pro ResponsesModel = "o1-pro"
ResponsesModelO1Pro2025_03_19 ResponsesModel = "o1-pro-2025-03-19"
+ ResponsesModelO3Pro ResponsesModel = "o3-pro"
+ ResponsesModelO3Pro2025_06_10 ResponsesModel = "o3-pro-2025-06-10"
+ ResponsesModelO3DeepResearch ResponsesModel = "o3-deep-research"
+ ResponsesModelO3DeepResearch2025_06_26 ResponsesModel = "o3-deep-research-2025-06-26"
+ ResponsesModelO4MiniDeepResearch ResponsesModel = "o4-mini-deep-research"
+ ResponsesModelO4MiniDeepResearch2025_06_26 ResponsesModel = "o4-mini-deep-research-2025-06-26"
ResponsesModelComputerUsePreview ResponsesModel = "computer-use-preview"
ResponsesModelComputerUsePreview2025_03_11 ResponsesModel = "computer-use-preview-2025-03-11"
// Or some ...[ChatModel]
@@ -13,7 +13,7 @@ type ChatCompletionAccumulator struct {
type FinishedChatCompletionToolCall struct {
ChatCompletionMessageToolCallFunction
Index int
- Id string
+ ID string
}
type chatCompletionResponseState struct {
@@ -52,7 +52,7 @@ func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool {
return true
}
-// JustFinishedRefusal retrieves the chat completion refusal when it is known to have just been completed.
+// JustFinishedContent retrieves the chat completion content when it is known to have just been completed.
// The content is "just completed" when the last added chunk no longer contains a content
// delta. If the content is just completed, the content is returned and the boolean is true. Otherwise,
// an empty string is returned and the boolean will be false.
@@ -86,7 +86,7 @@ func (acc *ChatCompletionAccumulator) JustFinishedToolCall() (toolcall FinishedC
f := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].Function
id := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].ID
return FinishedChatCompletionToolCall{
- Id: id,
+ ID: id,
Index: acc.justFinished.index,
ChatCompletionMessageToolCallFunction: ChatCompletionMessageToolCallFunction{
Name: f.Name,
@@ -163,11 +163,11 @@ func (prev *chatCompletionResponseState) update(chunk ChatCompletionChunk) (just
delta := chunk.Choices[0].Delta
new := chatCompletionResponseState{}
switch {
- case delta.JSON.Content.IsPresent():
+ case delta.JSON.Content.Valid():
new.state = contentResponseState
- case delta.JSON.Refusal.IsPresent():
+ case delta.JSON.Refusal.Valid():
new.state = refusalResponseState
- case delta.JSON.ToolCalls.IsPresent():
+ case delta.JSON.ToolCalls.Valid():
new.state = toolResponseState
new.index = int(delta.ToolCalls[0].Index)
default:
@@ -12,7 +12,7 @@ import (
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -123,19 +123,18 @@ type Upload struct {
Status UploadStatus `json:"status,required"`
// The `File` object represents a document that has been uploaded to OpenAI.
File FileObject `json:"file,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Bytes resp.Field
- CreatedAt resp.Field
- ExpiresAt resp.Field
- Filename resp.Field
- Object resp.Field
- Purpose resp.Field
- Status resp.Field
- File resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Bytes respjson.Field
+ CreatedAt respjson.Field
+ ExpiresAt respjson.Field
+ Filename respjson.Field
+ Object respjson.Field
+ Purpose respjson.Field
+ Status respjson.Field
+ File respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -176,14 +175,13 @@ type UploadNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f UploadNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r UploadNewParams) MarshalJSON() (data []byte, err error) {
type shadow UploadNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *UploadNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type UploadCompleteParams struct {
// The ordered list of Part IDs.
@@ -194,11 +192,10 @@ type UploadCompleteParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f UploadCompleteParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r UploadCompleteParams) MarshalJSON() (data []byte, err error) {
type shadow UploadCompleteParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *UploadCompleteParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
@@ -15,8 +15,7 @@ import (
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/requestconfig"
"github.com/openai/openai-go/option"
- "github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -71,14 +70,13 @@ type UploadPart struct {
Object constant.UploadPart `json:"object,required"`
// The ID of the Upload object that this Part was added to.
UploadID string `json:"upload_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- Object resp.Field
- UploadID resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Object respjson.Field
+ UploadID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -91,18 +89,17 @@ func (r *UploadPart) UnmarshalJSON(data []byte) error {
type UploadPartNewParams struct {
// The chunk of bytes for this Part.
- Data io.Reader `json:"data,required" format:"binary"`
+ Data io.Reader `json:"data,omitzero,required" format:"binary"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f UploadPartNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r UploadPartNewParams) MarshalMultipart() (data []byte, contentType string, err error) {
buf := bytes.NewBuffer(nil)
writer := multipart.NewWriter(buf)
err = apiform.MarshalRoot(r, writer)
+ if err == nil {
+ err = apiform.WriteExtras(writer, r.ExtraFields())
+ }
if err != nil {
writer.Close()
return nil, "", err
@@ -9,7 +9,6 @@ import (
"fmt"
"net/http"
"net/url"
- "reflect"
"github.com/openai/openai-go/internal/apijson"
"github.com/openai/openai-go/internal/apiquery"
@@ -17,10 +16,9 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared"
"github.com/openai/openai-go/shared/constant"
- "github.com/tidwall/gjson"
)
// VectorStoreService contains methods and other services that help with
@@ -146,25 +144,30 @@ func (r *VectorStoreService) SearchAutoPaging(ctx context.Context, vectorStoreID
return pagination.NewPageAutoPager(r.Search(ctx, vectorStoreID, body, opts...))
}
+func NewAutoFileChunkingStrategyParam() AutoFileChunkingStrategyParam {
+ return AutoFileChunkingStrategyParam{
+ Type: "auto",
+ }
+}
+
// The default strategy. This strategy currently uses a `max_chunk_size_tokens` of
// `800` and `chunk_overlap_tokens` of `400`.
//
-// The property Type is required.
+// This struct has a constant value, construct it with
+// [NewAutoFileChunkingStrategyParam].
type AutoFileChunkingStrategyParam struct {
// Always `auto`.
- //
- // This field can be elided, and will marshal its zero value as "auto".
Type constant.Auto `json:"type,required"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f AutoFileChunkingStrategyParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r AutoFileChunkingStrategyParam) MarshalJSON() (data []byte, err error) {
type shadow AutoFileChunkingStrategyParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *AutoFileChunkingStrategyParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// FileChunkingStrategyUnion contains all possible properties and values from
// [StaticFileChunkingStrategyObject], [OtherFileChunkingStrategyObject].
@@ -178,21 +181,31 @@ type FileChunkingStrategyUnion struct {
// Any of "static", "other".
Type string `json:"type"`
JSON struct {
- Static resp.Field
- Type resp.Field
+ Static respjson.Field
+ Type respjson.Field
raw string
} `json:"-"`
}
+// anyFileChunkingStrategy is implemented by each variant of
+// [FileChunkingStrategyUnion] to add type safety for the return type of
+// [FileChunkingStrategyUnion.AsAny]
+type anyFileChunkingStrategy interface {
+ implFileChunkingStrategyUnion()
+}
+
+func (StaticFileChunkingStrategyObject) implFileChunkingStrategyUnion() {}
+func (OtherFileChunkingStrategyObject) implFileChunkingStrategyUnion() {}
+
// Use the following switch statement to find the correct variant
//
// switch variant := FileChunkingStrategyUnion.AsAny().(type) {
-// case StaticFileChunkingStrategyObject:
-// case OtherFileChunkingStrategyObject:
+// case openai.StaticFileChunkingStrategyObject:
+// case openai.OtherFileChunkingStrategyObject:
// default:
// fmt.Errorf("no variant present")
// }
-func (u FileChunkingStrategyUnion) AsAny() any {
+func (u FileChunkingStrategyUnion) AsAny() anyFileChunkingStrategy {
switch u.Type {
case "static":
return u.AsStatic()
@@ -234,11 +247,11 @@ type FileChunkingStrategyParamUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u FileChunkingStrategyParamUnion) IsPresent() bool { return !param.IsOmitted(u) && !u.IsNull() }
func (u FileChunkingStrategyParamUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[FileChunkingStrategyParamUnion](u.OfAuto, u.OfStatic)
+ return param.MarshalUnion(u, u.OfAuto, u.OfStatic)
+}
+func (u *FileChunkingStrategyParamUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *FileChunkingStrategyParamUnion) asAny() any {
@@ -271,16 +284,8 @@ func (u FileChunkingStrategyParamUnion) GetType() *string {
func init() {
apijson.RegisterUnion[FileChunkingStrategyParamUnion](
"type",
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(AutoFileChunkingStrategyParam{}),
- DiscriminatorValue: "auto",
- },
- apijson.UnionVariant{
- TypeFilter: gjson.JSON,
- Type: reflect.TypeOf(StaticFileChunkingStrategyObjectParam{}),
- DiscriminatorValue: "static",
- },
+ apijson.Discriminator[AutoFileChunkingStrategyParam]("auto"),
+ apijson.Discriminator[StaticFileChunkingStrategyObjectParam]("static"),
)
}
@@ -290,11 +295,10 @@ func init() {
type OtherFileChunkingStrategyObject struct {
// Always `other`.
Type constant.Other `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -313,12 +317,11 @@ type StaticFileChunkingStrategy struct {
// The maximum number of tokens in each chunk. The default value is `800`. The
// minimum value is `100` and the maximum value is `4096`.
MaxChunkSizeTokens int64 `json:"max_chunk_size_tokens,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ChunkOverlapTokens resp.Field
- MaxChunkSizeTokens resp.Field
- ExtraFields map[string]resp.Field
+ ChunkOverlapTokens respjson.Field
+ MaxChunkSizeTokens respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -334,9 +337,9 @@ func (r *StaticFileChunkingStrategy) UnmarshalJSON(data []byte) error {
//
// Warning: the fields of the param type will not be present. ToParam should only
// be used at the last possible moment before sending a request. Test for this with
-// StaticFileChunkingStrategyParam.IsOverridden()
+// StaticFileChunkingStrategyParam.Overrides()
func (r StaticFileChunkingStrategy) ToParam() StaticFileChunkingStrategyParam {
- return param.OverrideObj[StaticFileChunkingStrategyParam](r.RawJSON())
+ return param.Override[StaticFileChunkingStrategyParam](json.RawMessage(r.RawJSON()))
}
// The properties ChunkOverlapTokens, MaxChunkSizeTokens are required.
@@ -351,24 +354,23 @@ type StaticFileChunkingStrategyParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f StaticFileChunkingStrategyParam) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r StaticFileChunkingStrategyParam) MarshalJSON() (data []byte, err error) {
type shadow StaticFileChunkingStrategyParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *StaticFileChunkingStrategyParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type StaticFileChunkingStrategyObject struct {
Static StaticFileChunkingStrategy `json:"static,required"`
// Always `static`.
Type constant.Static `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Static resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Static respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -391,15 +393,13 @@ type StaticFileChunkingStrategyObjectParam struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f StaticFileChunkingStrategyObjectParam) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r StaticFileChunkingStrategyObjectParam) MarshalJSON() (data []byte, err error) {
type shadow StaticFileChunkingStrategyObjectParam
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *StaticFileChunkingStrategyObjectParam) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// A vector store is a collection of processed files can be used by the
// `file_search` tool.
@@ -434,21 +434,20 @@ type VectorStore struct {
ExpiresAfter VectorStoreExpiresAfter `json:"expires_after"`
// The Unix timestamp (in seconds) for when the vector store will expire.
ExpiresAt int64 `json:"expires_at,nullable"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- FileCounts resp.Field
- LastActiveAt resp.Field
- Metadata resp.Field
- Name resp.Field
- Object resp.Field
- Status resp.Field
- UsageBytes resp.Field
- ExpiresAfter resp.Field
- ExpiresAt resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ FileCounts respjson.Field
+ LastActiveAt respjson.Field
+ Metadata respjson.Field
+ Name respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ UsageBytes respjson.Field
+ ExpiresAfter respjson.Field
+ ExpiresAt respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -470,15 +469,14 @@ type VectorStoreFileCounts struct {
InProgress int64 `json:"in_progress,required"`
// The total number of files.
Total int64 `json:"total,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Cancelled resp.Field
- Completed resp.Field
- Failed resp.Field
- InProgress resp.Field
- Total resp.Field
- ExtraFields map[string]resp.Field
+ Cancelled respjson.Field
+ Completed respjson.Field
+ Failed respjson.Field
+ InProgress respjson.Field
+ Total respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -507,12 +505,11 @@ type VectorStoreExpiresAfter struct {
Anchor constant.LastActiveAt `json:"anchor,required"`
// The number of days after the anchor time that the vector store will expire.
Days int64 `json:"days,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Anchor resp.Field
- Days resp.Field
- ExtraFields map[string]resp.Field
+ Anchor respjson.Field
+ Days respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -527,13 +524,12 @@ type VectorStoreDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object constant.VectorStoreDeleted `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -559,15 +555,14 @@ type VectorStoreSearchResponse struct {
Filename string `json:"filename,required"`
// The similarity score for the result.
Score float64 `json:"score,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Attributes resp.Field
- Content resp.Field
- FileID resp.Field
- Filename resp.Field
- Score resp.Field
- ExtraFields map[string]resp.Field
+ Attributes respjson.Field
+ Content respjson.Field
+ FileID respjson.Field
+ Filename respjson.Field
+ Score respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -593,9 +588,9 @@ type VectorStoreSearchResponseAttributeUnion struct {
// This field will be present if the value is a [bool] instead of an object.
OfBool bool `json:",inline"`
JSON struct {
- OfString resp.Field
- OfFloat resp.Field
- OfBool resp.Field
+ OfString respjson.Field
+ OfFloat respjson.Field
+ OfBool respjson.Field
raw string
} `json:"-"`
}
@@ -629,12 +624,11 @@ type VectorStoreSearchResponseContent struct {
//
// Any of "text".
Type string `json:"type,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -654,7 +648,7 @@ type VectorStoreNewParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
// The chunking strategy used to chunk the file(s). If not set, will use the `auto`
// strategy. Only applicable if `file_ids` is non-empty.
ChunkingStrategy FileChunkingStrategyParamUnion `json:"chunking_strategy,omitzero"`
@@ -667,14 +661,13 @@ type VectorStoreNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r VectorStoreNewParams) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The expiration policy for a vector store.
//
@@ -690,13 +683,13 @@ type VectorStoreNewParamsExpiresAfter struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreNewParamsExpiresAfter) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
func (r VectorStoreNewParamsExpiresAfter) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreNewParamsExpiresAfter
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreNewParamsExpiresAfter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type VectorStoreUpdateParams struct {
// The name of the vector store.
@@ -709,18 +702,17 @@ type VectorStoreUpdateParams struct {
//
// Keys are strings with a maximum length of 64 characters. Values are strings with
// a maximum length of 512 characters.
- Metadata shared.MetadataParam `json:"metadata,omitzero"`
+ Metadata shared.Metadata `json:"metadata,omitzero"`
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreUpdateParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r VectorStoreUpdateParams) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreUpdateParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreUpdateParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// The expiration policy for a vector store.
//
@@ -736,15 +728,13 @@ type VectorStoreUpdateParamsExpiresAfter struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreUpdateParamsExpiresAfter) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r VectorStoreUpdateParamsExpiresAfter) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreUpdateParamsExpiresAfter
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreUpdateParamsExpiresAfter) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
type VectorStoreListParams struct {
// A cursor for use in pagination. `after` is an object ID that defines your place
@@ -768,12 +758,8 @@ type VectorStoreListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [VectorStoreListParams]'s query parameters as `url.Values`.
-func (r VectorStoreListParams) URLQuery() (v url.Values) {
+func (r VectorStoreListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -804,38 +790,35 @@ type VectorStoreSearchParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreSearchParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r VectorStoreSearchParams) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreSearchParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreSearchParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
// Use [param.IsOmitted] to confirm if a field is set.
type VectorStoreSearchParamsQueryUnion struct {
- OfString param.Opt[string] `json:",omitzero,inline"`
- OfVectorStoreSearchsQueryArray []string `json:",omitzero,inline"`
+ OfString param.Opt[string] `json:",omitzero,inline"`
+ OfStringArray []string `json:",omitzero,inline"`
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u VectorStoreSearchParamsQueryUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u VectorStoreSearchParamsQueryUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[VectorStoreSearchParamsQueryUnion](u.OfString, u.OfVectorStoreSearchsQueryArray)
+ return param.MarshalUnion(u, u.OfString, u.OfStringArray)
+}
+func (u *VectorStoreSearchParamsQueryUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *VectorStoreSearchParamsQueryUnion) asAny() any {
if !param.IsOmitted(u.OfString) {
return &u.OfString.Value
- } else if !param.IsOmitted(u.OfVectorStoreSearchsQueryArray) {
- return &u.OfVectorStoreSearchsQueryArray
+ } else if !param.IsOmitted(u.OfStringArray) {
+ return &u.OfStringArray
}
return nil
}
@@ -849,13 +832,11 @@ type VectorStoreSearchParamsFiltersUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u VectorStoreSearchParamsFiltersUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u VectorStoreSearchParamsFiltersUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[VectorStoreSearchParamsFiltersUnion](u.OfComparisonFilter, u.OfCompoundFilter)
+ return param.MarshalUnion(u, u.OfComparisonFilter, u.OfCompoundFilter)
+}
+func (u *VectorStoreSearchParamsFiltersUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *VectorStoreSearchParamsFiltersUnion) asAny() any {
@@ -909,18 +890,16 @@ type VectorStoreSearchParamsRankingOptions struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreSearchParamsRankingOptions) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
func (r VectorStoreSearchParamsRankingOptions) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreSearchParamsRankingOptions
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreSearchParamsRankingOptions) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
func init() {
apijson.RegisterFieldValidator[VectorStoreSearchParamsRankingOptions](
- "Ranker", false, "auto", "default-2024-11-15",
+ "ranker", "auto", "default-2024-11-15",
)
}
@@ -16,7 +16,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -238,19 +238,18 @@ type VectorStoreFile struct {
Attributes map[string]VectorStoreFileAttributeUnion `json:"attributes,nullable"`
// The strategy used to chunk the file.
ChunkingStrategy FileChunkingStrategyUnion `json:"chunking_strategy"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- LastError resp.Field
- Object resp.Field
- Status resp.Field
- UsageBytes resp.Field
- VectorStoreID resp.Field
- Attributes resp.Field
- ChunkingStrategy resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ LastError respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ UsageBytes respjson.Field
+ VectorStoreID respjson.Field
+ Attributes respjson.Field
+ ChunkingStrategy respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -270,12 +269,11 @@ type VectorStoreFileLastError struct {
Code string `json:"code,required"`
// A human-readable description of the error.
Message string `json:"message,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Code resp.Field
- Message resp.Field
- ExtraFields map[string]resp.Field
+ Code respjson.Field
+ Message respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -313,9 +311,9 @@ type VectorStoreFileAttributeUnion struct {
// This field will be present if the value is a [bool] instead of an object.
OfBool bool `json:",inline"`
JSON struct {
- OfString resp.Field
- OfFloat resp.Field
- OfBool resp.Field
+ OfString respjson.Field
+ OfFloat respjson.Field
+ OfBool respjson.Field
raw string
} `json:"-"`
}
@@ -346,13 +344,12 @@ type VectorStoreFileDeleted struct {
ID string `json:"id,required"`
Deleted bool `json:"deleted,required"`
Object constant.VectorStoreFileDeleted `json:"object,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- Deleted resp.Field
- Object resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ Deleted respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -368,12 +365,11 @@ type VectorStoreFileContentResponse struct {
Text string `json:"text"`
// The content type (currently only `"text"`)
Type string `json:"type"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Text resp.Field
- Type resp.Field
- ExtraFields map[string]resp.Field
+ Text respjson.Field
+ Type respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -401,14 +397,13 @@ type VectorStoreFileNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreFileNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r VectorStoreFileNewParams) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreFileNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreFileNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -420,13 +415,11 @@ type VectorStoreFileNewParamsAttributeUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u VectorStoreFileNewParamsAttributeUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u VectorStoreFileNewParamsAttributeUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[VectorStoreFileNewParamsAttributeUnion](u.OfString, u.OfFloat, u.OfBool)
+ return param.MarshalUnion(u, u.OfString, u.OfFloat, u.OfBool)
+}
+func (u *VectorStoreFileNewParamsAttributeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *VectorStoreFileNewParamsAttributeUnion) asAny() any {
@@ -450,14 +443,13 @@ type VectorStoreFileUpdateParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreFileUpdateParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r VectorStoreFileUpdateParams) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreFileUpdateParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreFileUpdateParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -469,13 +461,11 @@ type VectorStoreFileUpdateParamsAttributeUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u VectorStoreFileUpdateParamsAttributeUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u VectorStoreFileUpdateParamsAttributeUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[VectorStoreFileUpdateParamsAttributeUnion](u.OfString, u.OfFloat, u.OfBool)
+ return param.MarshalUnion(u, u.OfString, u.OfFloat, u.OfBool)
+}
+func (u *VectorStoreFileUpdateParamsAttributeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *VectorStoreFileUpdateParamsAttributeUnion) asAny() any {
@@ -515,13 +505,9 @@ type VectorStoreFileListParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreFileListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
// URLQuery serializes [VectorStoreFileListParams]'s query parameters as
// `url.Values`.
-func (r VectorStoreFileListParams) URLQuery() (v url.Values) {
+func (r VectorStoreFileListParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -16,7 +16,7 @@ import (
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/pagination"
"github.com/openai/openai-go/packages/param"
- "github.com/openai/openai-go/packages/resp"
+ "github.com/openai/openai-go/packages/respjson"
"github.com/openai/openai-go/shared/constant"
)
@@ -197,16 +197,15 @@ type VectorStoreFileBatch struct {
// that the [File](https://platform.openai.com/docs/api-reference/files) is
// attached to.
VectorStoreID string `json:"vector_store_id,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- ID resp.Field
- CreatedAt resp.Field
- FileCounts resp.Field
- Object resp.Field
- Status resp.Field
- VectorStoreID resp.Field
- ExtraFields map[string]resp.Field
+ ID respjson.Field
+ CreatedAt respjson.Field
+ FileCounts respjson.Field
+ Object respjson.Field
+ Status respjson.Field
+ VectorStoreID respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -228,15 +227,14 @@ type VectorStoreFileBatchFileCounts struct {
InProgress int64 `json:"in_progress,required"`
// The total number of files.
Total int64 `json:"total,required"`
- // Metadata for the response, check the presence of optional fields with the
- // [resp.Field.IsPresent] method.
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
JSON struct {
- Cancelled resp.Field
- Completed resp.Field
- Failed resp.Field
- InProgress resp.Field
- Total resp.Field
- ExtraFields map[string]resp.Field
+ Cancelled respjson.Field
+ Completed respjson.Field
+ Failed respjson.Field
+ InProgress respjson.Field
+ Total respjson.Field
+ ExtraFields map[string]respjson.Field
raw string
} `json:"-"`
}
@@ -275,14 +273,13 @@ type VectorStoreFileBatchNewParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreFileBatchNewParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
-
func (r VectorStoreFileBatchNewParams) MarshalJSON() (data []byte, err error) {
type shadow VectorStoreFileBatchNewParams
return param.MarshalObject(r, (*shadow)(&r))
}
+func (r *VectorStoreFileBatchNewParams) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
// Only one field can be non-zero.
//
@@ -294,13 +291,11 @@ type VectorStoreFileBatchNewParamsAttributeUnion struct {
paramUnion
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (u VectorStoreFileBatchNewParamsAttributeUnion) IsPresent() bool {
- return !param.IsOmitted(u) && !u.IsNull()
-}
func (u VectorStoreFileBatchNewParamsAttributeUnion) MarshalJSON() ([]byte, error) {
- return param.MarshalUnion[VectorStoreFileBatchNewParamsAttributeUnion](u.OfString, u.OfFloat, u.OfBool)
+ return param.MarshalUnion(u, u.OfString, u.OfFloat, u.OfBool)
+}
+func (u *VectorStoreFileBatchNewParamsAttributeUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, u)
}
func (u *VectorStoreFileBatchNewParamsAttributeUnion) asAny() any {
@@ -340,15 +335,9 @@ type VectorStoreFileBatchListFilesParams struct {
paramObj
}
-// IsPresent returns true if the field's value is not omitted and not the JSON
-// "null". To check if this field is omitted, use [param.IsOmitted].
-func (f VectorStoreFileBatchListFilesParams) IsPresent() bool {
- return !param.IsOmitted(f) && !f.IsNull()
-}
-
// URLQuery serializes [VectorStoreFileBatchListFilesParams]'s query parameters as
// `url.Values`.
-func (r VectorStoreFileBatchListFilesParams) URLQuery() (v url.Values) {
+func (r VectorStoreFileBatchListFilesParams) URLQuery() (v url.Values, err error) {
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -0,0 +1,440 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package webhooks
+
+import (
+ "github.com/openai/openai-go/internal/apierror"
+ "github.com/openai/openai-go/packages/param"
+ "github.com/openai/openai-go/shared"
+)
+
+// aliased to make [param.APIUnion] private when embedding
+type paramUnion = param.APIUnion
+
+// aliased to make [param.APIObject] private when embedding
+type paramObj = param.APIObject
+
+type Error = apierror.Error
+
+// This is an alias to an internal type.
+type ChatModel = shared.ChatModel
+
+// Equals "gpt-4.1"
+const ChatModelGPT4_1 = shared.ChatModelGPT4_1
+
+// Equals "gpt-4.1-mini"
+const ChatModelGPT4_1Mini = shared.ChatModelGPT4_1Mini
+
+// Equals "gpt-4.1-nano"
+const ChatModelGPT4_1Nano = shared.ChatModelGPT4_1Nano
+
+// Equals "gpt-4.1-2025-04-14"
+const ChatModelGPT4_1_2025_04_14 = shared.ChatModelGPT4_1_2025_04_14
+
+// Equals "gpt-4.1-mini-2025-04-14"
+const ChatModelGPT4_1Mini2025_04_14 = shared.ChatModelGPT4_1Mini2025_04_14
+
+// Equals "gpt-4.1-nano-2025-04-14"
+const ChatModelGPT4_1Nano2025_04_14 = shared.ChatModelGPT4_1Nano2025_04_14
+
+// Equals "o4-mini"
+const ChatModelO4Mini = shared.ChatModelO4Mini
+
+// Equals "o4-mini-2025-04-16"
+const ChatModelO4Mini2025_04_16 = shared.ChatModelO4Mini2025_04_16
+
+// Equals "o3"
+const ChatModelO3 = shared.ChatModelO3
+
+// Equals "o3-2025-04-16"
+const ChatModelO3_2025_04_16 = shared.ChatModelO3_2025_04_16
+
+// Equals "o3-mini"
+const ChatModelO3Mini = shared.ChatModelO3Mini
+
+// Equals "o3-mini-2025-01-31"
+const ChatModelO3Mini2025_01_31 = shared.ChatModelO3Mini2025_01_31
+
+// Equals "o1"
+const ChatModelO1 = shared.ChatModelO1
+
+// Equals "o1-2024-12-17"
+const ChatModelO1_2024_12_17 = shared.ChatModelO1_2024_12_17
+
+// Equals "o1-preview"
+const ChatModelO1Preview = shared.ChatModelO1Preview
+
+// Equals "o1-preview-2024-09-12"
+const ChatModelO1Preview2024_09_12 = shared.ChatModelO1Preview2024_09_12
+
+// Equals "o1-mini"
+const ChatModelO1Mini = shared.ChatModelO1Mini
+
+// Equals "o1-mini-2024-09-12"
+const ChatModelO1Mini2024_09_12 = shared.ChatModelO1Mini2024_09_12
+
+// Equals "gpt-4o"
+const ChatModelGPT4o = shared.ChatModelGPT4o
+
+// Equals "gpt-4o-2024-11-20"
+const ChatModelGPT4o2024_11_20 = shared.ChatModelGPT4o2024_11_20
+
+// Equals "gpt-4o-2024-08-06"
+const ChatModelGPT4o2024_08_06 = shared.ChatModelGPT4o2024_08_06
+
+// Equals "gpt-4o-2024-05-13"
+const ChatModelGPT4o2024_05_13 = shared.ChatModelGPT4o2024_05_13
+
+// Equals "gpt-4o-audio-preview"
+const ChatModelGPT4oAudioPreview = shared.ChatModelGPT4oAudioPreview
+
+// Equals "gpt-4o-audio-preview-2024-10-01"
+const ChatModelGPT4oAudioPreview2024_10_01 = shared.ChatModelGPT4oAudioPreview2024_10_01
+
+// Equals "gpt-4o-audio-preview-2024-12-17"
+const ChatModelGPT4oAudioPreview2024_12_17 = shared.ChatModelGPT4oAudioPreview2024_12_17
+
+// Equals "gpt-4o-audio-preview-2025-06-03"
+const ChatModelGPT4oAudioPreview2025_06_03 = shared.ChatModelGPT4oAudioPreview2025_06_03
+
+// Equals "gpt-4o-mini-audio-preview"
+const ChatModelGPT4oMiniAudioPreview = shared.ChatModelGPT4oMiniAudioPreview
+
+// Equals "gpt-4o-mini-audio-preview-2024-12-17"
+const ChatModelGPT4oMiniAudioPreview2024_12_17 = shared.ChatModelGPT4oMiniAudioPreview2024_12_17
+
+// Equals "gpt-4o-search-preview"
+const ChatModelGPT4oSearchPreview = shared.ChatModelGPT4oSearchPreview
+
+// Equals "gpt-4o-mini-search-preview"
+const ChatModelGPT4oMiniSearchPreview = shared.ChatModelGPT4oMiniSearchPreview
+
+// Equals "gpt-4o-search-preview-2025-03-11"
+const ChatModelGPT4oSearchPreview2025_03_11 = shared.ChatModelGPT4oSearchPreview2025_03_11
+
+// Equals "gpt-4o-mini-search-preview-2025-03-11"
+const ChatModelGPT4oMiniSearchPreview2025_03_11 = shared.ChatModelGPT4oMiniSearchPreview2025_03_11
+
+// Equals "chatgpt-4o-latest"
+const ChatModelChatgpt4oLatest = shared.ChatModelChatgpt4oLatest
+
+// Equals "codex-mini-latest"
+const ChatModelCodexMiniLatest = shared.ChatModelCodexMiniLatest
+
+// Equals "gpt-4o-mini"
+const ChatModelGPT4oMini = shared.ChatModelGPT4oMini
+
+// Equals "gpt-4o-mini-2024-07-18"
+const ChatModelGPT4oMini2024_07_18 = shared.ChatModelGPT4oMini2024_07_18
+
+// Equals "gpt-4-turbo"
+const ChatModelGPT4Turbo = shared.ChatModelGPT4Turbo
+
+// Equals "gpt-4-turbo-2024-04-09"
+const ChatModelGPT4Turbo2024_04_09 = shared.ChatModelGPT4Turbo2024_04_09
+
+// Equals "gpt-4-0125-preview"
+const ChatModelGPT4_0125Preview = shared.ChatModelGPT4_0125Preview
+
+// Equals "gpt-4-turbo-preview"
+const ChatModelGPT4TurboPreview = shared.ChatModelGPT4TurboPreview
+
+// Equals "gpt-4-1106-preview"
+const ChatModelGPT4_1106Preview = shared.ChatModelGPT4_1106Preview
+
+// Equals "gpt-4-vision-preview"
+const ChatModelGPT4VisionPreview = shared.ChatModelGPT4VisionPreview
+
+// Equals "gpt-4"
+const ChatModelGPT4 = shared.ChatModelGPT4
+
+// Equals "gpt-4-0314"
+const ChatModelGPT4_0314 = shared.ChatModelGPT4_0314
+
+// Equals "gpt-4-0613"
+const ChatModelGPT4_0613 = shared.ChatModelGPT4_0613
+
+// Equals "gpt-4-32k"
+const ChatModelGPT4_32k = shared.ChatModelGPT4_32k
+
+// Equals "gpt-4-32k-0314"
+const ChatModelGPT4_32k0314 = shared.ChatModelGPT4_32k0314
+
+// Equals "gpt-4-32k-0613"
+const ChatModelGPT4_32k0613 = shared.ChatModelGPT4_32k0613
+
+// Equals "gpt-3.5-turbo"
+const ChatModelGPT3_5Turbo = shared.ChatModelGPT3_5Turbo
+
+// Equals "gpt-3.5-turbo-16k"
+const ChatModelGPT3_5Turbo16k = shared.ChatModelGPT3_5Turbo16k
+
+// Equals "gpt-3.5-turbo-0301"
+const ChatModelGPT3_5Turbo0301 = shared.ChatModelGPT3_5Turbo0301
+
+// Equals "gpt-3.5-turbo-0613"
+const ChatModelGPT3_5Turbo0613 = shared.ChatModelGPT3_5Turbo0613
+
+// Equals "gpt-3.5-turbo-1106"
+const ChatModelGPT3_5Turbo1106 = shared.ChatModelGPT3_5Turbo1106
+
+// Equals "gpt-3.5-turbo-0125"
+const ChatModelGPT3_5Turbo0125 = shared.ChatModelGPT3_5Turbo0125
+
+// Equals "gpt-3.5-turbo-16k-0613"
+const ChatModelGPT3_5Turbo16k0613 = shared.ChatModelGPT3_5Turbo16k0613
+
+// A filter used to compare a specified attribute key to a given value using a
+// defined comparison operation.
+//
+// This is an alias to an internal type.
+type ComparisonFilter = shared.ComparisonFilter
+
+// Specifies the comparison operator: `eq`, `ne`, `gt`, `gte`, `lt`, `lte`.
+//
+// - `eq`: equals
+// - `ne`: not equal
+// - `gt`: greater than
+// - `gte`: greater than or equal
+// - `lt`: less than
+// - `lte`: less than or equal
+//
+// This is an alias to an internal type.
+type ComparisonFilterType = shared.ComparisonFilterType
+
+// Equals "eq"
+const ComparisonFilterTypeEq = shared.ComparisonFilterTypeEq
+
+// Equals "ne"
+const ComparisonFilterTypeNe = shared.ComparisonFilterTypeNe
+
+// Equals "gt"
+const ComparisonFilterTypeGt = shared.ComparisonFilterTypeGt
+
+// Equals "gte"
+const ComparisonFilterTypeGte = shared.ComparisonFilterTypeGte
+
+// Equals "lt"
+const ComparisonFilterTypeLt = shared.ComparisonFilterTypeLt
+
+// Equals "lte"
+const ComparisonFilterTypeLte = shared.ComparisonFilterTypeLte
+
+// The value to compare against the attribute key; supports string, number, or
+// boolean types.
+//
+// This is an alias to an internal type.
+type ComparisonFilterValueUnion = shared.ComparisonFilterValueUnion
+
+// A filter used to compare a specified attribute key to a given value using a
+// defined comparison operation.
+//
+// This is an alias to an internal type.
+type ComparisonFilterParam = shared.ComparisonFilterParam
+
+// The value to compare against the attribute key; supports string, number, or
+// boolean types.
+//
+// This is an alias to an internal type.
+type ComparisonFilterValueUnionParam = shared.ComparisonFilterValueUnionParam
+
+// Combine multiple filters using `and` or `or`.
+//
+// This is an alias to an internal type.
+type CompoundFilter = shared.CompoundFilter
+
+// Type of operation: `and` or `or`.
+//
+// This is an alias to an internal type.
+type CompoundFilterType = shared.CompoundFilterType
+
+// Equals "and"
+const CompoundFilterTypeAnd = shared.CompoundFilterTypeAnd
+
+// Equals "or"
+const CompoundFilterTypeOr = shared.CompoundFilterTypeOr
+
+// Combine multiple filters using `and` or `or`.
+//
+// This is an alias to an internal type.
+type CompoundFilterParam = shared.CompoundFilterParam
+
+// This is an alias to an internal type.
+type ErrorObject = shared.ErrorObject
+
+// This is an alias to an internal type.
+type FunctionDefinition = shared.FunctionDefinition
+
+// This is an alias to an internal type.
+type FunctionDefinitionParam = shared.FunctionDefinitionParam
+
+// The parameters the functions accepts, described as a JSON Schema object. See the
+// [guide](https://platform.openai.com/docs/guides/function-calling) for examples,
+// and the
+// [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for
+// documentation about the format.
+//
+// Omitting `parameters` defines a function with an empty parameter list.
+//
+// This is an alias to an internal type.
+type FunctionParameters = shared.FunctionParameters
+
+// Set of 16 key-value pairs that can be attached to an object. This can be useful
+// for storing additional information about the object in a structured format, and
+// querying for objects via API or the dashboard.
+//
+// Keys are strings with a maximum length of 64 characters. Values are strings with
+// a maximum length of 512 characters.
+//
+// This is an alias to an internal type.
+type Metadata = shared.Metadata
+
+// **o-series models only**
+//
+// Configuration options for
+// [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+//
+// This is an alias to an internal type.
+type Reasoning = shared.Reasoning
+
+// **Deprecated:** use `summary` instead.
+//
+// A summary of the reasoning performed by the model. This can be useful for
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
+//
+// This is an alias to an internal type.
+type ReasoningGenerateSummary = shared.ReasoningGenerateSummary
+
+// Equals "auto"
+const ReasoningGenerateSummaryAuto = shared.ReasoningGenerateSummaryAuto
+
+// Equals "concise"
+const ReasoningGenerateSummaryConcise = shared.ReasoningGenerateSummaryConcise
+
+// Equals "detailed"
+const ReasoningGenerateSummaryDetailed = shared.ReasoningGenerateSummaryDetailed
+
+// A summary of the reasoning performed by the model. This can be useful for
+// debugging and understanding the model's reasoning process. One of `auto`,
+// `concise`, or `detailed`.
+//
+// This is an alias to an internal type.
+type ReasoningSummary = shared.ReasoningSummary
+
+// Equals "auto"
+const ReasoningSummaryAuto = shared.ReasoningSummaryAuto
+
+// Equals "concise"
+const ReasoningSummaryConcise = shared.ReasoningSummaryConcise
+
+// Equals "detailed"
+const ReasoningSummaryDetailed = shared.ReasoningSummaryDetailed
+
+// **o-series models only**
+//
+// Configuration options for
+// [reasoning models](https://platform.openai.com/docs/guides/reasoning).
+//
+// This is an alias to an internal type.
+type ReasoningParam = shared.ReasoningParam
+
+// **o-series models only**
+//
+// Constrains effort on reasoning for
+// [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently
+// supported values are `low`, `medium`, and `high`. Reducing reasoning effort can
+// result in faster responses and fewer tokens used on reasoning in a response.
+//
+// This is an alias to an internal type.
+type ReasoningEffort = shared.ReasoningEffort
+
+// Equals "low"
+const ReasoningEffortLow = shared.ReasoningEffortLow
+
+// Equals "medium"
+const ReasoningEffortMedium = shared.ReasoningEffortMedium
+
+// Equals "high"
+const ReasoningEffortHigh = shared.ReasoningEffortHigh
+
+// JSON object response format. An older method of generating JSON responses. Using
+// `json_schema` is recommended for models that support it. Note that the model
+// will not generate JSON without a system or user message instructing it to do so.
+//
+// This is an alias to an internal type.
+type ResponseFormatJSONObject = shared.ResponseFormatJSONObject
+
+// JSON object response format. An older method of generating JSON responses. Using
+// `json_schema` is recommended for models that support it. Note that the model
+// will not generate JSON without a system or user message instructing it to do so.
+//
+// This is an alias to an internal type.
+type ResponseFormatJSONObjectParam = shared.ResponseFormatJSONObjectParam
+
+// JSON Schema response format. Used to generate structured JSON responses. Learn
+// more about
+// [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
+//
+// This is an alias to an internal type.
+type ResponseFormatJSONSchema = shared.ResponseFormatJSONSchema
+
+// Structured Outputs configuration options, including a JSON Schema.
+//
+// This is an alias to an internal type.
+type ResponseFormatJSONSchemaJSONSchema = shared.ResponseFormatJSONSchemaJSONSchema
+
+// JSON Schema response format. Used to generate structured JSON responses. Learn
+// more about
+// [Structured Outputs](https://platform.openai.com/docs/guides/structured-outputs).
+//
+// This is an alias to an internal type.
+type ResponseFormatJSONSchemaParam = shared.ResponseFormatJSONSchemaParam
+
+// Structured Outputs configuration options, including a JSON Schema.
+//
+// This is an alias to an internal type.
+type ResponseFormatJSONSchemaJSONSchemaParam = shared.ResponseFormatJSONSchemaJSONSchemaParam
+
+// Default response format. Used to generate text responses.
+//
+// This is an alias to an internal type.
+type ResponseFormatText = shared.ResponseFormatText
+
+// Default response format. Used to generate text responses.
+//
+// This is an alias to an internal type.
+type ResponseFormatTextParam = shared.ResponseFormatTextParam
+
+// This is an alias to an internal type.
+type ResponsesModel = shared.ResponsesModel
+
+// Equals "o1-pro"
+const ResponsesModelO1Pro = shared.ResponsesModelO1Pro
+
+// Equals "o1-pro-2025-03-19"
+const ResponsesModelO1Pro2025_03_19 = shared.ResponsesModelO1Pro2025_03_19
+
+// Equals "o3-pro"
+const ResponsesModelO3Pro = shared.ResponsesModelO3Pro
+
+// Equals "o3-pro-2025-06-10"
+const ResponsesModelO3Pro2025_06_10 = shared.ResponsesModelO3Pro2025_06_10
+
+// Equals "o3-deep-research"
+const ResponsesModelO3DeepResearch = shared.ResponsesModelO3DeepResearch
+
+// Equals "o3-deep-research-2025-06-26"
+const ResponsesModelO3DeepResearch2025_06_26 = shared.ResponsesModelO3DeepResearch2025_06_26
+
+// Equals "o4-mini-deep-research"
+const ResponsesModelO4MiniDeepResearch = shared.ResponsesModelO4MiniDeepResearch
+
+// Equals "o4-mini-deep-research-2025-06-26"
+const ResponsesModelO4MiniDeepResearch2025_06_26 = shared.ResponsesModelO4MiniDeepResearch2025_06_26
+
+// Equals "computer-use-preview"
+const ResponsesModelComputerUsePreview = shared.ResponsesModelComputerUsePreview
+
+// Equals "computer-use-preview-2025-03-11"
+const ResponsesModelComputerUsePreview2025_03_11 = shared.ResponsesModelComputerUsePreview2025_03_11
@@ -0,0 +1,1208 @@
+// File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+package webhooks
+
+import (
+ "crypto/hmac"
+ "crypto/sha256"
+ "crypto/subtle"
+ "encoding/base64"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/openai/openai-go/internal/apijson"
+ "github.com/openai/openai-go/internal/requestconfig"
+ "github.com/openai/openai-go/option"
+ "github.com/openai/openai-go/packages/respjson"
+ "github.com/openai/openai-go/shared/constant"
+)
+
+// WebhookService contains methods and other services that help with interacting
+// with the openai API.
+//
+// Note, unlike clients, this service does not read variables from the environment
+// automatically. You should not instantiate this service directly, and instead use
+// the [NewWebhookService] method instead.
+type WebhookService struct {
+ Options []option.RequestOption
+}
+
+// NewWebhookService generates a new service that applies the given options to each
+// request. These options are applied after the parent client's options (if there
+// is one), and before any request-specific options.
+func NewWebhookService(opts ...option.RequestOption) (r WebhookService) {
+ r = WebhookService{}
+ r.Options = opts
+ return
+}
+
+// Validates that the given payload was sent by OpenAI and parses the payload.
+func (r *WebhookService) Unwrap(body []byte, headers http.Header, opts ...option.RequestOption) (*UnwrapWebhookEventUnion, error) {
+ // Always perform signature verification
+ err := r.VerifySignature(body, headers, opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ res := &UnwrapWebhookEventUnion{}
+ err = res.UnmarshalJSON(body)
+ if err != nil {
+ return res, err
+ }
+ return res, nil
+}
+
+// UnwrapWithTolerance validates that the given payload was sent by OpenAI using custom tolerance, then parses the payload.
+// tolerance specifies the maximum age of the webhook.
+func (r *WebhookService) UnwrapWithTolerance(body []byte, headers http.Header, tolerance time.Duration, opts ...option.RequestOption) (*UnwrapWebhookEventUnion, error) {
+ err := r.VerifySignatureWithTolerance(body, headers, tolerance, opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ res := &UnwrapWebhookEventUnion{}
+ err = res.UnmarshalJSON(body)
+ if err != nil {
+ return res, err
+ }
+ return res, nil
+}
+
+// UnwrapWithToleranceAndTime validates that the given payload was sent by OpenAI using custom tolerance and time, then parses the payload.
+// tolerance specifies the maximum age of the webhook.
+// now allows specifying the current time for testing purposes.
+func (r *WebhookService) UnwrapWithToleranceAndTime(body []byte, headers http.Header, tolerance time.Duration, now time.Time, opts ...option.RequestOption) (*UnwrapWebhookEventUnion, error) {
+ err := r.VerifySignatureWithToleranceAndTime(body, headers, tolerance, now, opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ res := &UnwrapWebhookEventUnion{}
+ err = res.UnmarshalJSON(body)
+ if err != nil {
+ return res, err
+ }
+ return res, nil
+}
+
+// VerifySignature validates whether or not the webhook payload was sent by OpenAI.
+// An error will be raised if the webhook signature is invalid.
+// tolerance specifies the maximum age of the webhook (default: 5 minutes).
+func (r *WebhookService) VerifySignature(body []byte, headers http.Header, opts ...option.RequestOption) error {
+ return r.VerifySignatureWithTolerance(body, headers, 5*time.Minute, opts...)
+}
+
+// VerifySignatureWithTolerance validates whether or not the webhook payload was sent by OpenAI.
+// An error will be raised if the webhook signature is invalid.
+// tolerance specifies the maximum age of the webhook.
+func (r *WebhookService) VerifySignatureWithTolerance(body []byte, headers http.Header, tolerance time.Duration, opts ...option.RequestOption) error {
+ return r.VerifySignatureWithToleranceAndTime(body, headers, tolerance, time.Now(), opts...)
+}
+
+// VerifySignatureWithToleranceAndTime validates whether or not the webhook payload was sent by OpenAI.
+// An error will be raised if the webhook signature is invalid.
+// tolerance specifies the maximum age of the webhook.
+// now allows specifying the current time for testing purposes.
+func (r *WebhookService) VerifySignatureWithToleranceAndTime(body []byte, headers http.Header, tolerance time.Duration, now time.Time, opts ...option.RequestOption) error {
+ cfg, err := requestconfig.PreRequestOptions(r.Options...)
+ if err != nil {
+ return err
+ }
+ webhookSecret := cfg.WebhookSecret
+
+ if webhookSecret == "" {
+ return errors.New("webhook secret must be provided either in the method call or configured on the client")
+ }
+
+ if headers == nil {
+ return errors.New("headers are required for webhook verification")
+ }
+
+ // Extract required headers
+ signatureHeader := headers.Get("webhook-signature")
+ if signatureHeader == "" {
+ return errors.New("missing required webhook-signature header")
+ }
+
+ timestampHeader := headers.Get("webhook-timestamp")
+ if timestampHeader == "" {
+ return errors.New("missing required webhook-timestamp header")
+ }
+
+ webhookID := headers.Get("webhook-id")
+ if webhookID == "" {
+ return errors.New("missing required webhook-id header")
+ }
+
+ // Validate timestamp to prevent replay attacks
+ timestampSeconds, err := strconv.ParseInt(timestampHeader, 10, 64)
+ if err != nil {
+ return errors.New("invalid webhook timestamp format")
+ }
+
+ nowUnix := now.Unix()
+ toleranceSeconds := int64(tolerance.Seconds())
+
+ if nowUnix-timestampSeconds > toleranceSeconds {
+ return errors.New("webhook timestamp is too old")
+ }
+
+ if timestampSeconds > nowUnix+toleranceSeconds {
+ return errors.New("webhook timestamp is too new")
+ }
+
+ // Extract signatures from v1,<base64> format
+ // The signature header can have multiple values, separated by spaces.
+ // Each value is in the format v1,<base64>. We should accept if any match.
+ var signatures []string
+ for _, part := range strings.Fields(signatureHeader) {
+ if strings.HasPrefix(part, "v1,") {
+ signatures = append(signatures, part[3:])
+ } else {
+ signatures = append(signatures, part)
+ }
+ }
+
+ // Decode the secret if it starts with whsec_
+ var decodedSecret []byte
+ if strings.HasPrefix(webhookSecret, "whsec_") {
+ decodedSecret, err = base64.StdEncoding.DecodeString(webhookSecret[6:])
+ if err != nil {
+ return fmt.Errorf("invalid webhook secret format: %v", err)
+ }
+ } else {
+ decodedSecret = []byte(webhookSecret)
+ }
+
+ // Create the signed payload: {webhook_id}.{timestamp}.{payload}
+ signedPayload := fmt.Sprintf("%s.%s.%s", webhookID, timestampHeader, string(body))
+
+ // Compute HMAC-SHA256 signature
+ h := hmac.New(sha256.New, decodedSecret)
+ h.Write([]byte(signedPayload))
+ expectedSignature := base64.StdEncoding.EncodeToString(h.Sum(nil))
+
+ // Accept if any signature matches using timing-safe comparison
+ for _, signature := range signatures {
+ if subtle.ConstantTimeCompare([]byte(expectedSignature), []byte(signature)) == 1 {
+ return nil
+ }
+ }
+
+ return errors.New("webhook signature verification failed")
+}
+
+// Sent when a batch API request has been cancelled.
+type BatchCancelledWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the batch API request was cancelled.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data BatchCancelledWebhookEventData `json:"data,required"`
+ // The type of the event. Always `batch.cancelled`.
+ Type constant.BatchCancelled `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object BatchCancelledWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchCancelledWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *BatchCancelledWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type BatchCancelledWebhookEventData struct {
+ // The unique ID of the batch API request.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchCancelledWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *BatchCancelledWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type BatchCancelledWebhookEventObject string
+
+const (
+ BatchCancelledWebhookEventObjectEvent BatchCancelledWebhookEventObject = "event"
+)
+
+// Sent when a batch API request has been completed.
+type BatchCompletedWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the batch API request was completed.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data BatchCompletedWebhookEventData `json:"data,required"`
+ // The type of the event. Always `batch.completed`.
+ Type constant.BatchCompleted `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object BatchCompletedWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchCompletedWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *BatchCompletedWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type BatchCompletedWebhookEventData struct {
+ // The unique ID of the batch API request.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchCompletedWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *BatchCompletedWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type BatchCompletedWebhookEventObject string
+
+const (
+ BatchCompletedWebhookEventObjectEvent BatchCompletedWebhookEventObject = "event"
+)
+
+// Sent when a batch API request has expired.
+type BatchExpiredWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the batch API request expired.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data BatchExpiredWebhookEventData `json:"data,required"`
+ // The type of the event. Always `batch.expired`.
+ Type constant.BatchExpired `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object BatchExpiredWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchExpiredWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *BatchExpiredWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type BatchExpiredWebhookEventData struct {
+ // The unique ID of the batch API request.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchExpiredWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *BatchExpiredWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type BatchExpiredWebhookEventObject string
+
+const (
+ BatchExpiredWebhookEventObjectEvent BatchExpiredWebhookEventObject = "event"
+)
+
+// Sent when a batch API request has failed.
+type BatchFailedWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the batch API request failed.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data BatchFailedWebhookEventData `json:"data,required"`
+ // The type of the event. Always `batch.failed`.
+ Type constant.BatchFailed `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object BatchFailedWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchFailedWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *BatchFailedWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type BatchFailedWebhookEventData struct {
+ // The unique ID of the batch API request.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r BatchFailedWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *BatchFailedWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type BatchFailedWebhookEventObject string
+
+const (
+ BatchFailedWebhookEventObjectEvent BatchFailedWebhookEventObject = "event"
+)
+
+// Sent when an eval run has been canceled.
+type EvalRunCanceledWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the eval run was canceled.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data EvalRunCanceledWebhookEventData `json:"data,required"`
+ // The type of the event. Always `eval.run.canceled`.
+ Type constant.EvalRunCanceled `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object EvalRunCanceledWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EvalRunCanceledWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *EvalRunCanceledWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type EvalRunCanceledWebhookEventData struct {
+ // The unique ID of the eval run.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EvalRunCanceledWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *EvalRunCanceledWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type EvalRunCanceledWebhookEventObject string
+
+const (
+ EvalRunCanceledWebhookEventObjectEvent EvalRunCanceledWebhookEventObject = "event"
+)
+
+// Sent when an eval run has failed.
+type EvalRunFailedWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the eval run failed.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data EvalRunFailedWebhookEventData `json:"data,required"`
+ // The type of the event. Always `eval.run.failed`.
+ Type constant.EvalRunFailed `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object EvalRunFailedWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EvalRunFailedWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *EvalRunFailedWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type EvalRunFailedWebhookEventData struct {
+ // The unique ID of the eval run.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EvalRunFailedWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *EvalRunFailedWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type EvalRunFailedWebhookEventObject string
+
+const (
+ EvalRunFailedWebhookEventObjectEvent EvalRunFailedWebhookEventObject = "event"
+)
+
+// Sent when an eval run has succeeded.
+type EvalRunSucceededWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the eval run succeeded.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data EvalRunSucceededWebhookEventData `json:"data,required"`
+ // The type of the event. Always `eval.run.succeeded`.
+ Type constant.EvalRunSucceeded `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object EvalRunSucceededWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EvalRunSucceededWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *EvalRunSucceededWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type EvalRunSucceededWebhookEventData struct {
+ // The unique ID of the eval run.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r EvalRunSucceededWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *EvalRunSucceededWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type EvalRunSucceededWebhookEventObject string
+
+const (
+ EvalRunSucceededWebhookEventObjectEvent EvalRunSucceededWebhookEventObject = "event"
+)
+
+// Sent when a fine-tuning job has been cancelled.
+type FineTuningJobCancelledWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the fine-tuning job was cancelled.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data FineTuningJobCancelledWebhookEventData `json:"data,required"`
+ // The type of the event. Always `fine_tuning.job.cancelled`.
+ Type constant.FineTuningJobCancelled `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object FineTuningJobCancelledWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningJobCancelledWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningJobCancelledWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type FineTuningJobCancelledWebhookEventData struct {
+ // The unique ID of the fine-tuning job.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningJobCancelledWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningJobCancelledWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type FineTuningJobCancelledWebhookEventObject string
+
+const (
+ FineTuningJobCancelledWebhookEventObjectEvent FineTuningJobCancelledWebhookEventObject = "event"
+)
+
+// Sent when a fine-tuning job has failed.
+type FineTuningJobFailedWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the fine-tuning job failed.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data FineTuningJobFailedWebhookEventData `json:"data,required"`
+ // The type of the event. Always `fine_tuning.job.failed`.
+ Type constant.FineTuningJobFailed `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object FineTuningJobFailedWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningJobFailedWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningJobFailedWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type FineTuningJobFailedWebhookEventData struct {
+ // The unique ID of the fine-tuning job.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningJobFailedWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningJobFailedWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type FineTuningJobFailedWebhookEventObject string
+
+const (
+ FineTuningJobFailedWebhookEventObjectEvent FineTuningJobFailedWebhookEventObject = "event"
+)
+
+// Sent when a fine-tuning job has succeeded.
+type FineTuningJobSucceededWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the fine-tuning job succeeded.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data FineTuningJobSucceededWebhookEventData `json:"data,required"`
+ // The type of the event. Always `fine_tuning.job.succeeded`.
+ Type constant.FineTuningJobSucceeded `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object FineTuningJobSucceededWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningJobSucceededWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningJobSucceededWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type FineTuningJobSucceededWebhookEventData struct {
+ // The unique ID of the fine-tuning job.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r FineTuningJobSucceededWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *FineTuningJobSucceededWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type FineTuningJobSucceededWebhookEventObject string
+
+const (
+ FineTuningJobSucceededWebhookEventObjectEvent FineTuningJobSucceededWebhookEventObject = "event"
+)
+
+// Sent when a background response has been cancelled.
+type ResponseCancelledWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the model response was cancelled.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data ResponseCancelledWebhookEventData `json:"data,required"`
+ // The type of the event. Always `response.cancelled`.
+ Type constant.ResponseCancelled `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object ResponseCancelledWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseCancelledWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *ResponseCancelledWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type ResponseCancelledWebhookEventData struct {
+ // The unique ID of the model response.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseCancelledWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *ResponseCancelledWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type ResponseCancelledWebhookEventObject string
+
+const (
+ ResponseCancelledWebhookEventObjectEvent ResponseCancelledWebhookEventObject = "event"
+)
+
+// Sent when a background response has been completed.
+type ResponseCompletedWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the model response was completed.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data ResponseCompletedWebhookEventData `json:"data,required"`
+ // The type of the event. Always `response.completed`.
+ Type constant.ResponseCompleted `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object ResponseCompletedWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseCompletedWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *ResponseCompletedWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type ResponseCompletedWebhookEventData struct {
+ // The unique ID of the model response.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseCompletedWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *ResponseCompletedWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type ResponseCompletedWebhookEventObject string
+
+const (
+ ResponseCompletedWebhookEventObjectEvent ResponseCompletedWebhookEventObject = "event"
+)
+
+// Sent when a background response has failed.
+type ResponseFailedWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the model response failed.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data ResponseFailedWebhookEventData `json:"data,required"`
+ // The type of the event. Always `response.failed`.
+ Type constant.ResponseFailed `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object ResponseFailedWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseFailedWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *ResponseFailedWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type ResponseFailedWebhookEventData struct {
+ // The unique ID of the model response.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseFailedWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *ResponseFailedWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type ResponseFailedWebhookEventObject string
+
+const (
+ ResponseFailedWebhookEventObjectEvent ResponseFailedWebhookEventObject = "event"
+)
+
+// Sent when a background response has been interrupted.
+type ResponseIncompleteWebhookEvent struct {
+ // The unique ID of the event.
+ ID string `json:"id,required"`
+ // The Unix timestamp (in seconds) of when the model response was interrupted.
+ CreatedAt int64 `json:"created_at,required"`
+ // Event data payload.
+ Data ResponseIncompleteWebhookEventData `json:"data,required"`
+ // The type of the event. Always `response.incomplete`.
+ Type constant.ResponseIncomplete `json:"type,required"`
+ // The object of the event. Always `event`.
+ //
+ // Any of "event".
+ Object ResponseIncompleteWebhookEventObject `json:"object"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseIncompleteWebhookEvent) RawJSON() string { return r.JSON.raw }
+func (r *ResponseIncompleteWebhookEvent) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// Event data payload.
+type ResponseIncompleteWebhookEventData struct {
+ // The unique ID of the model response.
+ ID string `json:"id,required"`
+ // JSON contains metadata for fields, check presence with [respjson.Field.Valid].
+ JSON struct {
+ ID respjson.Field
+ ExtraFields map[string]respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// Returns the unmodified JSON received from the API
+func (r ResponseIncompleteWebhookEventData) RawJSON() string { return r.JSON.raw }
+func (r *ResponseIncompleteWebhookEventData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// The object of the event. Always `event`.
+type ResponseIncompleteWebhookEventObject string
+
+const (
+ ResponseIncompleteWebhookEventObjectEvent ResponseIncompleteWebhookEventObject = "event"
+)
+
+// UnwrapWebhookEventUnion contains all possible properties and values from
+// [BatchCancelledWebhookEvent], [BatchCompletedWebhookEvent],
+// [BatchExpiredWebhookEvent], [BatchFailedWebhookEvent],
+// [EvalRunCanceledWebhookEvent], [EvalRunFailedWebhookEvent],
+// [EvalRunSucceededWebhookEvent], [FineTuningJobCancelledWebhookEvent],
+// [FineTuningJobFailedWebhookEvent], [FineTuningJobSucceededWebhookEvent],
+// [ResponseCancelledWebhookEvent], [ResponseCompletedWebhookEvent],
+// [ResponseFailedWebhookEvent], [ResponseIncompleteWebhookEvent].
+//
+// Use the [UnwrapWebhookEventUnion.AsAny] method to switch on the variant.
+//
+// Use the methods beginning with 'As' to cast the union to one of its variants.
+type UnwrapWebhookEventUnion struct {
+ ID string `json:"id"`
+ CreatedAt int64 `json:"created_at"`
+ // This field is a union of [BatchCancelledWebhookEventData],
+ // [BatchCompletedWebhookEventData], [BatchExpiredWebhookEventData],
+ // [BatchFailedWebhookEventData], [EvalRunCanceledWebhookEventData],
+ // [EvalRunFailedWebhookEventData], [EvalRunSucceededWebhookEventData],
+ // [FineTuningJobCancelledWebhookEventData], [FineTuningJobFailedWebhookEventData],
+ // [FineTuningJobSucceededWebhookEventData], [ResponseCancelledWebhookEventData],
+ // [ResponseCompletedWebhookEventData], [ResponseFailedWebhookEventData],
+ // [ResponseIncompleteWebhookEventData]
+ Data UnwrapWebhookEventUnionData `json:"data"`
+ // Any of "batch.cancelled", "batch.completed", "batch.expired", "batch.failed",
+ // "eval.run.canceled", "eval.run.failed", "eval.run.succeeded",
+ // "fine_tuning.job.cancelled", "fine_tuning.job.failed",
+ // "fine_tuning.job.succeeded", "response.cancelled", "response.completed",
+ // "response.failed", "response.incomplete".
+ Type string `json:"type"`
+ Object string `json:"object"`
+ JSON struct {
+ ID respjson.Field
+ CreatedAt respjson.Field
+ Data respjson.Field
+ Type respjson.Field
+ Object respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+// anyUnwrapWebhookEvent is implemented by each variant of
+// [UnwrapWebhookEventUnion] to add type safety for the return type of
+// [UnwrapWebhookEventUnion.AsAny]
+type anyUnwrapWebhookEvent interface {
+ implUnwrapWebhookEventUnion()
+}
+
+func (BatchCancelledWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (BatchCompletedWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (BatchExpiredWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (BatchFailedWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (EvalRunCanceledWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (EvalRunFailedWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (EvalRunSucceededWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (FineTuningJobCancelledWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (FineTuningJobFailedWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (FineTuningJobSucceededWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (ResponseCancelledWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (ResponseCompletedWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (ResponseFailedWebhookEvent) implUnwrapWebhookEventUnion() {}
+func (ResponseIncompleteWebhookEvent) implUnwrapWebhookEventUnion() {}
+
+// Use the following switch statement to find the correct variant
+//
+// switch variant := UnwrapWebhookEventUnion.AsAny().(type) {
+// case webhooks.BatchCancelledWebhookEvent:
+// case webhooks.BatchCompletedWebhookEvent:
+// case webhooks.BatchExpiredWebhookEvent:
+// case webhooks.BatchFailedWebhookEvent:
+// case webhooks.EvalRunCanceledWebhookEvent:
+// case webhooks.EvalRunFailedWebhookEvent:
+// case webhooks.EvalRunSucceededWebhookEvent:
+// case webhooks.FineTuningJobCancelledWebhookEvent:
+// case webhooks.FineTuningJobFailedWebhookEvent:
+// case webhooks.FineTuningJobSucceededWebhookEvent:
+// case webhooks.ResponseCancelledWebhookEvent:
+// case webhooks.ResponseCompletedWebhookEvent:
+// case webhooks.ResponseFailedWebhookEvent:
+// case webhooks.ResponseIncompleteWebhookEvent:
+// default:
+// fmt.Errorf("no variant present")
+// }
+func (u UnwrapWebhookEventUnion) AsAny() anyUnwrapWebhookEvent {
+ switch u.Type {
+ case "batch.cancelled":
+ return u.AsBatchCancelled()
+ case "batch.completed":
+ return u.AsBatchCompleted()
+ case "batch.expired":
+ return u.AsBatchExpired()
+ case "batch.failed":
+ return u.AsBatchFailed()
+ case "eval.run.canceled":
+ return u.AsEvalRunCanceled()
+ case "eval.run.failed":
+ return u.AsEvalRunFailed()
+ case "eval.run.succeeded":
+ return u.AsEvalRunSucceeded()
+ case "fine_tuning.job.cancelled":
+ return u.AsFineTuningJobCancelled()
+ case "fine_tuning.job.failed":
+ return u.AsFineTuningJobFailed()
+ case "fine_tuning.job.succeeded":
+ return u.AsFineTuningJobSucceeded()
+ case "response.cancelled":
+ return u.AsResponseCancelled()
+ case "response.completed":
+ return u.AsResponseCompleted()
+ case "response.failed":
+ return u.AsResponseFailed()
+ case "response.incomplete":
+ return u.AsResponseIncomplete()
+ }
+ return nil
+}
+
+func (u UnwrapWebhookEventUnion) AsBatchCancelled() (v BatchCancelledWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsBatchCompleted() (v BatchCompletedWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsBatchExpired() (v BatchExpiredWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsBatchFailed() (v BatchFailedWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsEvalRunCanceled() (v EvalRunCanceledWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsEvalRunFailed() (v EvalRunFailedWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsEvalRunSucceeded() (v EvalRunSucceededWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsFineTuningJobCancelled() (v FineTuningJobCancelledWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsFineTuningJobFailed() (v FineTuningJobFailedWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsFineTuningJobSucceeded() (v FineTuningJobSucceededWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsResponseCancelled() (v ResponseCancelledWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsResponseCompleted() (v ResponseCompletedWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsResponseFailed() (v ResponseFailedWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+func (u UnwrapWebhookEventUnion) AsResponseIncomplete() (v ResponseIncompleteWebhookEvent) {
+ apijson.UnmarshalRoot(json.RawMessage(u.JSON.raw), &v)
+ return
+}
+
+// Returns the unmodified JSON received from the API
+func (u UnwrapWebhookEventUnion) RawJSON() string { return u.JSON.raw }
+
+func (r *UnwrapWebhookEventUnion) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
+
+// UnwrapWebhookEventUnionData is an implicit subunion of
+// [UnwrapWebhookEventUnion]. UnwrapWebhookEventUnionData provides convenient
+// access to the sub-properties of the union.
+//
+// For type safety it is recommended to directly use a variant of the
+// [UnwrapWebhookEventUnion].
+type UnwrapWebhookEventUnionData struct {
+ ID string `json:"id"`
+ JSON struct {
+ ID respjson.Field
+ raw string
+ } `json:"-"`
+}
+
+func (r *UnwrapWebhookEventUnionData) UnmarshalJSON(data []byte) error {
+ return apijson.UnmarshalRoot(data, r)
+}
@@ -0,0 +1,25 @@
+# Compiled Object files, Static and Dynamic libs (Shared Objects)
+*.o
+*.a
+*.so
+
+# Folders
+_obj
+_test
+
+# Architecture specific extensions/prefixes
+*.[568vq]
+[568vq].out
+
+*.cgo1.go
+*.cgo2.c
+_cgo_defun.c
+_cgo_gotypes.go
+_cgo_export.*
+
+_testmain.go
+
+*.exe
+*.test
+
+*.bench
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2014 Steve Francia
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
@@ -0,0 +1,40 @@
+GOVERSION := $(shell go version | cut -d ' ' -f 3 | cut -d '.' -f 2)
+
+.PHONY: check fmt lint test test-race vet test-cover-html help
+.DEFAULT_GOAL := help
+
+check: test-race fmt vet lint ## Run tests and linters
+
+test: ## Run tests
+ go test ./...
+
+test-race: ## Run tests with race detector
+ go test -race ./...
+
+fmt: ## Run gofmt linter
+ifeq "$(GOVERSION)" "12"
+ @for d in `go list` ; do \
+ if [ "`gofmt -l -s $$GOPATH/src/$$d | tee /dev/stderr`" ]; then \
+ echo "^ improperly formatted go files" && echo && exit 1; \
+ fi \
+ done
+endif
+
+lint: ## Run golint linter
+ @for d in `go list` ; do \
+ if [ "`golint $$d | tee /dev/stderr`" ]; then \
+ echo "^ golint errors!" && echo && exit 1; \
+ fi \
+ done
+
+vet: ## Run go vet linter
+ @if [ "`go vet | tee /dev/stderr`" ]; then \
+ echo "^ go vet errors!" && echo && exit 1; \
+ fi
+
+test-cover-html: ## Generate test coverage report
+ go test -coverprofile=coverage.out -covermode=count
+ go tool cover -func=coverage.out
+
+help:
+ @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
@@ -0,0 +1,75 @@
+# cast
+
+[](https://github.com/spf13/cast/actions/workflows/test.yaml)
+[](https://pkg.go.dev/mod/github.com/spf13/cast)
+
+[](https://goreportcard.com/report/github.com/spf13/cast)
+
+Easy and safe casting from one type to another in Go
+
+Donโt Panic! ... Cast
+
+## What is Cast?
+
+Cast is a library to convert between different go types in a consistent and easy way.
+
+Cast provides simple functions to easily convert a number to a string, an
+interface into a bool, etc. Cast does this intelligently when an obvious
+conversion is possible. It doesnโt make any attempts to guess what you meant,
+for example you can only convert a string to an int when it is a string
+representation of an int such as โ8โ. Cast was developed for use in
+[Hugo](https://gohugo.io), a website engine which uses YAML, TOML or JSON
+for meta data.
+
+## Why use Cast?
+
+When working with dynamic data in Go you often need to cast or convert the data
+from one type into another. Cast goes beyond just using type assertion (though
+it uses that when possible) to provide a very straightforward and convenient
+library.
+
+If you are working with interfaces to handle things like dynamic content
+youโll need an easy way to convert an interface into a given type. This
+is the library for you.
+
+If you are taking in data from YAML, TOML or JSON or other formats which lack
+full types, then Cast is the library for you.
+
+## Usage
+
+Cast provides a handful of To_____ methods. These methods will always return
+the desired type. **If input is provided that will not convert to that type, the
+0 or nil value for that type will be returned**.
+
+Cast also provides identical methods To_____E. These return the same result as
+the To_____ methods, plus an additional error which tells you if it successfully
+converted. Using these methods you can tell the difference between when the
+input matched the zero value or when the conversion failed and the zero value
+was returned.
+
+The following examples are merely a sample of what is available. Please review
+the code for a complete set.
+
+### Example โToStringโ:
+
+ cast.ToString("mayonegg") // "mayonegg"
+ cast.ToString(8) // "8"
+ cast.ToString(8.31) // "8.31"
+ cast.ToString([]byte("one time")) // "one time"
+ cast.ToString(nil) // ""
+
+ var foo interface{} = "one more time"
+ cast.ToString(foo) // "one more time"
+
+
+### Example โToIntโ:
+
+ cast.ToInt(8) // 8
+ cast.ToInt(8.31) // 8
+ cast.ToInt("8") // 8
+ cast.ToInt(true) // 1
+ cast.ToInt(false) // 0
+
+ var eight interface{} = 8
+ cast.ToInt(eight) // 8
+ cast.ToInt(nil) // 0
@@ -0,0 +1,176 @@
+// Copyright ยฉ 2014 Steve Francia <spf@spf13.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+// Package cast provides easy and safe casting in Go.
+package cast
+
+import "time"
+
+// ToBool casts an interface to a bool type.
+func ToBool(i interface{}) bool {
+ v, _ := ToBoolE(i)
+ return v
+}
+
+// ToTime casts an interface to a time.Time type.
+func ToTime(i interface{}) time.Time {
+ v, _ := ToTimeE(i)
+ return v
+}
+
+func ToTimeInDefaultLocation(i interface{}, location *time.Location) time.Time {
+ v, _ := ToTimeInDefaultLocationE(i, location)
+ return v
+}
+
+// ToDuration casts an interface to a time.Duration type.
+func ToDuration(i interface{}) time.Duration {
+ v, _ := ToDurationE(i)
+ return v
+}
+
+// ToFloat64 casts an interface to a float64 type.
+func ToFloat64(i interface{}) float64 {
+ v, _ := ToFloat64E(i)
+ return v
+}
+
+// ToFloat32 casts an interface to a float32 type.
+func ToFloat32(i interface{}) float32 {
+ v, _ := ToFloat32E(i)
+ return v
+}
+
+// ToInt64 casts an interface to an int64 type.
+func ToInt64(i interface{}) int64 {
+ v, _ := ToInt64E(i)
+ return v
+}
+
+// ToInt32 casts an interface to an int32 type.
+func ToInt32(i interface{}) int32 {
+ v, _ := ToInt32E(i)
+ return v
+}
+
+// ToInt16 casts an interface to an int16 type.
+func ToInt16(i interface{}) int16 {
+ v, _ := ToInt16E(i)
+ return v
+}
+
+// ToInt8 casts an interface to an int8 type.
+func ToInt8(i interface{}) int8 {
+ v, _ := ToInt8E(i)
+ return v
+}
+
+// ToInt casts an interface to an int type.
+func ToInt(i interface{}) int {
+ v, _ := ToIntE(i)
+ return v
+}
+
+// ToUint casts an interface to a uint type.
+func ToUint(i interface{}) uint {
+ v, _ := ToUintE(i)
+ return v
+}
+
+// ToUint64 casts an interface to a uint64 type.
+func ToUint64(i interface{}) uint64 {
+ v, _ := ToUint64E(i)
+ return v
+}
+
+// ToUint32 casts an interface to a uint32 type.
+func ToUint32(i interface{}) uint32 {
+ v, _ := ToUint32E(i)
+ return v
+}
+
+// ToUint16 casts an interface to a uint16 type.
+func ToUint16(i interface{}) uint16 {
+ v, _ := ToUint16E(i)
+ return v
+}
+
+// ToUint8 casts an interface to a uint8 type.
+func ToUint8(i interface{}) uint8 {
+ v, _ := ToUint8E(i)
+ return v
+}
+
+// ToString casts an interface to a string type.
+func ToString(i interface{}) string {
+ v, _ := ToStringE(i)
+ return v
+}
+
+// ToStringMapString casts an interface to a map[string]string type.
+func ToStringMapString(i interface{}) map[string]string {
+ v, _ := ToStringMapStringE(i)
+ return v
+}
+
+// ToStringMapStringSlice casts an interface to a map[string][]string type.
+func ToStringMapStringSlice(i interface{}) map[string][]string {
+ v, _ := ToStringMapStringSliceE(i)
+ return v
+}
+
+// ToStringMapBool casts an interface to a map[string]bool type.
+func ToStringMapBool(i interface{}) map[string]bool {
+ v, _ := ToStringMapBoolE(i)
+ return v
+}
+
+// ToStringMapInt casts an interface to a map[string]int type.
+func ToStringMapInt(i interface{}) map[string]int {
+ v, _ := ToStringMapIntE(i)
+ return v
+}
+
+// ToStringMapInt64 casts an interface to a map[string]int64 type.
+func ToStringMapInt64(i interface{}) map[string]int64 {
+ v, _ := ToStringMapInt64E(i)
+ return v
+}
+
+// ToStringMap casts an interface to a map[string]interface{} type.
+func ToStringMap(i interface{}) map[string]interface{} {
+ v, _ := ToStringMapE(i)
+ return v
+}
+
+// ToSlice casts an interface to a []interface{} type.
+func ToSlice(i interface{}) []interface{} {
+ v, _ := ToSliceE(i)
+ return v
+}
+
+// ToBoolSlice casts an interface to a []bool type.
+func ToBoolSlice(i interface{}) []bool {
+ v, _ := ToBoolSliceE(i)
+ return v
+}
+
+// ToStringSlice casts an interface to a []string type.
+func ToStringSlice(i interface{}) []string {
+ v, _ := ToStringSliceE(i)
+ return v
+}
+
+// ToIntSlice casts an interface to a []int type.
+func ToIntSlice(i interface{}) []int {
+ v, _ := ToIntSliceE(i)
+ return v
+}
+
+// ToDurationSlice casts an interface to a []time.Duration type.
+func ToDurationSlice(i interface{}) []time.Duration {
+ v, _ := ToDurationSliceE(i)
+ return v
+}
@@ -0,0 +1,1510 @@
+// Copyright ยฉ 2014 Steve Francia <spf@spf13.com>.
+//
+// Use of this source code is governed by an MIT-style
+// license that can be found in the LICENSE file.
+
+package cast
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "html/template"
+ "reflect"
+ "strconv"
+ "strings"
+ "time"
+)
+
+var errNegativeNotAllowed = errors.New("unable to cast negative value")
+
+type float64EProvider interface {
+ Float64() (float64, error)
+}
+
+type float64Provider interface {
+ Float64() float64
+}
+
+// ToTimeE casts an interface to a time.Time type.
+func ToTimeE(i interface{}) (tim time.Time, err error) {
+ return ToTimeInDefaultLocationE(i, time.UTC)
+}
+
+// ToTimeInDefaultLocationE casts an empty interface to time.Time,
+// interpreting inputs without a timezone to be in the given location,
+// or the local timezone if nil.
+func ToTimeInDefaultLocationE(i interface{}, location *time.Location) (tim time.Time, err error) {
+ i = indirect(i)
+
+ switch v := i.(type) {
+ case time.Time:
+ return v, nil
+ case string:
+ return StringToDateInDefaultLocation(v, location)
+ case json.Number:
+ s, err1 := ToInt64E(v)
+ if err1 != nil {
+ return time.Time{}, fmt.Errorf("unable to cast %#v of type %T to Time", i, i)
+ }
+ return time.Unix(s, 0), nil
+ case int:
+ return time.Unix(int64(v), 0), nil
+ case int64:
+ return time.Unix(v, 0), nil
+ case int32:
+ return time.Unix(int64(v), 0), nil
+ case uint:
+ return time.Unix(int64(v), 0), nil
+ case uint64:
+ return time.Unix(int64(v), 0), nil
+ case uint32:
+ return time.Unix(int64(v), 0), nil
+ default:
+ return time.Time{}, fmt.Errorf("unable to cast %#v of type %T to Time", i, i)
+ }
+}
+
+// ToDurationE casts an interface to a time.Duration type.
+func ToDurationE(i interface{}) (d time.Duration, err error) {
+ i = indirect(i)
+
+ switch s := i.(type) {
+ case time.Duration:
+ return s, nil
+ case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8:
+ d = time.Duration(ToInt64(s))
+ return
+ case float32, float64:
+ d = time.Duration(ToFloat64(s))
+ return
+ case string:
+ if strings.ContainsAny(s, "nsuยตmh") {
+ d, err = time.ParseDuration(s)
+ } else {
+ d, err = time.ParseDuration(s + "ns")
+ }
+ return
+ case float64EProvider:
+ var v float64
+ v, err = s.Float64()
+ d = time.Duration(v)
+ return
+ case float64Provider:
+ d = time.Duration(s.Float64())
+ return
+ default:
+ err = fmt.Errorf("unable to cast %#v of type %T to Duration", i, i)
+ return
+ }
+}
+
+// ToBoolE casts an interface to a bool type.
+func ToBoolE(i interface{}) (bool, error) {
+ i = indirect(i)
+
+ switch b := i.(type) {
+ case bool:
+ return b, nil
+ case nil:
+ return false, nil
+ case int:
+ return b != 0, nil
+ case int64:
+ return b != 0, nil
+ case int32:
+ return b != 0, nil
+ case int16:
+ return b != 0, nil
+ case int8:
+ return b != 0, nil
+ case uint:
+ return b != 0, nil
+ case uint64:
+ return b != 0, nil
+ case uint32:
+ return b != 0, nil
+ case uint16:
+ return b != 0, nil
+ case uint8:
+ return b != 0, nil
+ case float64:
+ return b != 0, nil
+ case float32:
+ return b != 0, nil
+ case time.Duration:
+ return b != 0, nil
+ case string:
+ return strconv.ParseBool(i.(string))
+ case json.Number:
+ v, err := ToInt64E(b)
+ if err == nil {
+ return v != 0, nil
+ }
+ return false, fmt.Errorf("unable to cast %#v of type %T to bool", i, i)
+ default:
+ return false, fmt.Errorf("unable to cast %#v of type %T to bool", i, i)
+ }
+}
+
+// ToFloat64E casts an interface to a float64 type.
+func ToFloat64E(i interface{}) (float64, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return float64(intv), nil
+ }
+
+ switch s := i.(type) {
+ case float64:
+ return s, nil
+ case float32:
+ return float64(s), nil
+ case int64:
+ return float64(s), nil
+ case int32:
+ return float64(s), nil
+ case int16:
+ return float64(s), nil
+ case int8:
+ return float64(s), nil
+ case uint:
+ return float64(s), nil
+ case uint64:
+ return float64(s), nil
+ case uint32:
+ return float64(s), nil
+ case uint16:
+ return float64(s), nil
+ case uint8:
+ return float64(s), nil
+ case string:
+ v, err := strconv.ParseFloat(s, 64)
+ if err == nil {
+ return v, nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i)
+ case float64EProvider:
+ v, err := s.Float64()
+ if err == nil {
+ return v, nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i)
+ case float64Provider:
+ return s.Float64(), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i)
+ }
+}
+
+// ToFloat32E casts an interface to a float32 type.
+func ToFloat32E(i interface{}) (float32, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return float32(intv), nil
+ }
+
+ switch s := i.(type) {
+ case float64:
+ return float32(s), nil
+ case float32:
+ return s, nil
+ case int64:
+ return float32(s), nil
+ case int32:
+ return float32(s), nil
+ case int16:
+ return float32(s), nil
+ case int8:
+ return float32(s), nil
+ case uint:
+ return float32(s), nil
+ case uint64:
+ return float32(s), nil
+ case uint32:
+ return float32(s), nil
+ case uint16:
+ return float32(s), nil
+ case uint8:
+ return float32(s), nil
+ case string:
+ v, err := strconv.ParseFloat(s, 32)
+ if err == nil {
+ return float32(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i)
+ case float64EProvider:
+ v, err := s.Float64()
+ if err == nil {
+ return float32(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i)
+ case float64Provider:
+ return float32(s.Float64()), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i)
+ }
+}
+
+// ToInt64E casts an interface to an int64 type.
+func ToInt64E(i interface{}) (int64, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return int64(intv), nil
+ }
+
+ switch s := i.(type) {
+ case int64:
+ return s, nil
+ case int32:
+ return int64(s), nil
+ case int16:
+ return int64(s), nil
+ case int8:
+ return int64(s), nil
+ case uint:
+ return int64(s), nil
+ case uint64:
+ return int64(s), nil
+ case uint32:
+ return int64(s), nil
+ case uint16:
+ return int64(s), nil
+ case uint8:
+ return int64(s), nil
+ case float64:
+ return int64(s), nil
+ case float32:
+ return int64(s), nil
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ return v, nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i)
+ case json.Number:
+ return ToInt64E(string(s))
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i)
+ }
+}
+
+// ToInt32E casts an interface to an int32 type.
+func ToInt32E(i interface{}) (int32, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return int32(intv), nil
+ }
+
+ switch s := i.(type) {
+ case int64:
+ return int32(s), nil
+ case int32:
+ return s, nil
+ case int16:
+ return int32(s), nil
+ case int8:
+ return int32(s), nil
+ case uint:
+ return int32(s), nil
+ case uint64:
+ return int32(s), nil
+ case uint32:
+ return int32(s), nil
+ case uint16:
+ return int32(s), nil
+ case uint8:
+ return int32(s), nil
+ case float64:
+ return int32(s), nil
+ case float32:
+ return int32(s), nil
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ return int32(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int32", i, i)
+ case json.Number:
+ return ToInt32E(string(s))
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int32", i, i)
+ }
+}
+
+// ToInt16E casts an interface to an int16 type.
+func ToInt16E(i interface{}) (int16, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return int16(intv), nil
+ }
+
+ switch s := i.(type) {
+ case int64:
+ return int16(s), nil
+ case int32:
+ return int16(s), nil
+ case int16:
+ return s, nil
+ case int8:
+ return int16(s), nil
+ case uint:
+ return int16(s), nil
+ case uint64:
+ return int16(s), nil
+ case uint32:
+ return int16(s), nil
+ case uint16:
+ return int16(s), nil
+ case uint8:
+ return int16(s), nil
+ case float64:
+ return int16(s), nil
+ case float32:
+ return int16(s), nil
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ return int16(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int16", i, i)
+ case json.Number:
+ return ToInt16E(string(s))
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int16", i, i)
+ }
+}
+
+// ToInt8E casts an interface to an int8 type.
+func ToInt8E(i interface{}) (int8, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return int8(intv), nil
+ }
+
+ switch s := i.(type) {
+ case int64:
+ return int8(s), nil
+ case int32:
+ return int8(s), nil
+ case int16:
+ return int8(s), nil
+ case int8:
+ return s, nil
+ case uint:
+ return int8(s), nil
+ case uint64:
+ return int8(s), nil
+ case uint32:
+ return int8(s), nil
+ case uint16:
+ return int8(s), nil
+ case uint8:
+ return int8(s), nil
+ case float64:
+ return int8(s), nil
+ case float32:
+ return int8(s), nil
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ return int8(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int8", i, i)
+ case json.Number:
+ return ToInt8E(string(s))
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int8", i, i)
+ }
+}
+
+// ToIntE casts an interface to an int type.
+func ToIntE(i interface{}) (int, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ return intv, nil
+ }
+
+ switch s := i.(type) {
+ case int64:
+ return int(s), nil
+ case int32:
+ return int(s), nil
+ case int16:
+ return int(s), nil
+ case int8:
+ return int(s), nil
+ case uint:
+ return int(s), nil
+ case uint64:
+ return int(s), nil
+ case uint32:
+ return int(s), nil
+ case uint16:
+ return int(s), nil
+ case uint8:
+ return int(s), nil
+ case float64:
+ return int(s), nil
+ case float32:
+ return int(s), nil
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ return int(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i)
+ case json.Number:
+ return ToIntE(string(s))
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to int", i, i)
+ }
+}
+
+// ToUintE casts an interface to a uint type.
+func ToUintE(i interface{}) (uint, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ if intv < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(intv), nil
+ }
+
+ switch s := i.(type) {
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ if v < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint", i, i)
+ case json.Number:
+ return ToUintE(string(s))
+ case int64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(s), nil
+ case int32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(s), nil
+ case int16:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(s), nil
+ case int8:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(s), nil
+ case uint:
+ return s, nil
+ case uint64:
+ return uint(s), nil
+ case uint32:
+ return uint(s), nil
+ case uint16:
+ return uint(s), nil
+ case uint8:
+ return uint(s), nil
+ case float64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(s), nil
+ case float32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint(s), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint", i, i)
+ }
+}
+
+// ToUint64E casts an interface to a uint64 type.
+func ToUint64E(i interface{}) (uint64, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ if intv < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(intv), nil
+ }
+
+ switch s := i.(type) {
+ case string:
+ v, err := strconv.ParseUint(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ if v < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return v, nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint64", i, i)
+ case json.Number:
+ return ToUint64E(string(s))
+ case int64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(s), nil
+ case int32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(s), nil
+ case int16:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(s), nil
+ case int8:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(s), nil
+ case uint:
+ return uint64(s), nil
+ case uint64:
+ return s, nil
+ case uint32:
+ return uint64(s), nil
+ case uint16:
+ return uint64(s), nil
+ case uint8:
+ return uint64(s), nil
+ case float32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(s), nil
+ case float64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint64(s), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint64", i, i)
+ }
+}
+
+// ToUint32E casts an interface to a uint32 type.
+func ToUint32E(i interface{}) (uint32, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ if intv < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(intv), nil
+ }
+
+ switch s := i.(type) {
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ if v < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint32", i, i)
+ case json.Number:
+ return ToUint32E(string(s))
+ case int64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(s), nil
+ case int32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(s), nil
+ case int16:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(s), nil
+ case int8:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(s), nil
+ case uint:
+ return uint32(s), nil
+ case uint64:
+ return uint32(s), nil
+ case uint32:
+ return s, nil
+ case uint16:
+ return uint32(s), nil
+ case uint8:
+ return uint32(s), nil
+ case float64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(s), nil
+ case float32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint32(s), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint32", i, i)
+ }
+}
+
+// ToUint16E casts an interface to a uint16 type.
+func ToUint16E(i interface{}) (uint16, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ if intv < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(intv), nil
+ }
+
+ switch s := i.(type) {
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ if v < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint16", i, i)
+ case json.Number:
+ return ToUint16E(string(s))
+ case int64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(s), nil
+ case int32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(s), nil
+ case int16:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(s), nil
+ case int8:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(s), nil
+ case uint:
+ return uint16(s), nil
+ case uint64:
+ return uint16(s), nil
+ case uint32:
+ return uint16(s), nil
+ case uint16:
+ return s, nil
+ case uint8:
+ return uint16(s), nil
+ case float64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(s), nil
+ case float32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint16(s), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint16", i, i)
+ }
+}
+
+// ToUint8E casts an interface to a uint type.
+func ToUint8E(i interface{}) (uint8, error) {
+ i = indirect(i)
+
+ intv, ok := toInt(i)
+ if ok {
+ if intv < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(intv), nil
+ }
+
+ switch s := i.(type) {
+ case string:
+ v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0)
+ if err == nil {
+ if v < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(v), nil
+ }
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint8", i, i)
+ case json.Number:
+ return ToUint8E(string(s))
+ case int64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(s), nil
+ case int32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(s), nil
+ case int16:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(s), nil
+ case int8:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(s), nil
+ case uint:
+ return uint8(s), nil
+ case uint64:
+ return uint8(s), nil
+ case uint32:
+ return uint8(s), nil
+ case uint16:
+ return uint8(s), nil
+ case uint8:
+ return s, nil
+ case float64:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(s), nil
+ case float32:
+ if s < 0 {
+ return 0, errNegativeNotAllowed
+ }
+ return uint8(s), nil
+ case bool:
+ if s {
+ return 1, nil
+ }
+ return 0, nil
+ case nil:
+ return 0, nil
+ default:
+ return 0, fmt.Errorf("unable to cast %#v of type %T to uint8", i, i)
+ }
+}
+
+// From html/template/content.go
+// Copyright 2011 The Go Authors. All rights reserved.
+// indirect returns the value, after dereferencing as many times
+// as necessary to reach the base type (or nil).
+func indirect(a interface{}) interface{} {
+ if a == nil {
+ return nil
+ }
+ if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr {
+ // Avoid creating a reflect.Value if it's not a pointer.
+ return a
+ }
+ v := reflect.ValueOf(a)
+ for v.Kind() == reflect.Ptr && !v.IsNil() {
+ v = v.Elem()
+ }
+ return v.Interface()
+}
+
+// From html/template/content.go
+// Copyright 2011 The Go Authors. All rights reserved.
+// indirectToStringerOrError returns the value, after dereferencing as many times
+// as necessary to reach the base type (or nil) or an implementation of fmt.Stringer
+// or error,
+func indirectToStringerOrError(a interface{}) interface{} {
+ if a == nil {
+ return nil
+ }
+
+ errorType := reflect.TypeOf((*error)(nil)).Elem()
+ fmtStringerType := reflect.TypeOf((*fmt.Stringer)(nil)).Elem()
+
+ v := reflect.ValueOf(a)
+ for !v.Type().Implements(fmtStringerType) && !v.Type().Implements(errorType) && v.Kind() == reflect.Ptr && !v.IsNil() {
+ v = v.Elem()
+ }
+ return v.Interface()
+}
+
+// ToStringE casts an interface to a string type.
+func ToStringE(i interface{}) (string, error) {
+ i = indirectToStringerOrError(i)
+
+ switch s := i.(type) {
+ case string:
+ return s, nil
+ case bool:
+ return strconv.FormatBool(s), nil
+ case float64:
+ return strconv.FormatFloat(s, 'f', -1, 64), nil
+ case float32:
+ return strconv.FormatFloat(float64(s), 'f', -1, 32), nil
+ case int:
+ return strconv.Itoa(s), nil
+ case int64:
+ return strconv.FormatInt(s, 10), nil
+ case int32:
+ return strconv.Itoa(int(s)), nil
+ case int16:
+ return strconv.FormatInt(int64(s), 10), nil
+ case int8:
+ return strconv.FormatInt(int64(s), 10), nil
+ case uint:
+ return strconv.FormatUint(uint64(s), 10), nil
+ case uint64:
+ return strconv.FormatUint(uint64(s), 10), nil
+ case uint32:
+ return strconv.FormatUint(uint64(s), 10), nil
+ case uint16:
+ return strconv.FormatUint(uint64(s), 10), nil
+ case uint8:
+ return strconv.FormatUint(uint64(s), 10), nil
+ case json.Number:
+ return s.String(), nil
+ case []byte:
+ return string(s), nil
+ case template.HTML:
+ return string(s), nil
+ case template.URL:
+ return string(s), nil
+ case template.JS:
+ return string(s), nil
+ case template.CSS:
+ return string(s), nil
+ case template.HTMLAttr:
+ return string(s), nil
+ case nil:
+ return "", nil
+ case fmt.Stringer:
+ return s.String(), nil
+ case error:
+ return s.Error(), nil
+ default:
+ return "", fmt.Errorf("unable to cast %#v of type %T to string", i, i)
+ }
+}
+
+// ToStringMapStringE casts an interface to a map[string]string type.
+func ToStringMapStringE(i interface{}) (map[string]string, error) {
+ m := map[string]string{}
+
+ switch v := i.(type) {
+ case map[string]string:
+ return v, nil
+ case map[string]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToString(val)
+ }
+ return m, nil
+ case map[interface{}]string:
+ for k, val := range v {
+ m[ToString(k)] = ToString(val)
+ }
+ return m, nil
+ case map[interface{}]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToString(val)
+ }
+ return m, nil
+ case string:
+ err := jsonStringToObject(v, &m)
+ return m, err
+ default:
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]string", i, i)
+ }
+}
+
+// ToStringMapStringSliceE casts an interface to a map[string][]string type.
+func ToStringMapStringSliceE(i interface{}) (map[string][]string, error) {
+ m := map[string][]string{}
+
+ switch v := i.(type) {
+ case map[string][]string:
+ return v, nil
+ case map[string][]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToStringSlice(val)
+ }
+ return m, nil
+ case map[string]string:
+ for k, val := range v {
+ m[ToString(k)] = []string{val}
+ }
+ case map[string]interface{}:
+ for k, val := range v {
+ switch vt := val.(type) {
+ case []interface{}:
+ m[ToString(k)] = ToStringSlice(vt)
+ case []string:
+ m[ToString(k)] = vt
+ default:
+ m[ToString(k)] = []string{ToString(val)}
+ }
+ }
+ return m, nil
+ case map[interface{}][]string:
+ for k, val := range v {
+ m[ToString(k)] = ToStringSlice(val)
+ }
+ return m, nil
+ case map[interface{}]string:
+ for k, val := range v {
+ m[ToString(k)] = ToStringSlice(val)
+ }
+ return m, nil
+ case map[interface{}][]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToStringSlice(val)
+ }
+ return m, nil
+ case map[interface{}]interface{}:
+ for k, val := range v {
+ key, err := ToStringE(k)
+ if err != nil {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i)
+ }
+ value, err := ToStringSliceE(val)
+ if err != nil {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i)
+ }
+ m[key] = value
+ }
+ case string:
+ err := jsonStringToObject(v, &m)
+ return m, err
+ default:
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i)
+ }
+ return m, nil
+}
+
+// ToStringMapBoolE casts an interface to a map[string]bool type.
+func ToStringMapBoolE(i interface{}) (map[string]bool, error) {
+ m := map[string]bool{}
+
+ switch v := i.(type) {
+ case map[interface{}]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToBool(val)
+ }
+ return m, nil
+ case map[string]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToBool(val)
+ }
+ return m, nil
+ case map[string]bool:
+ return v, nil
+ case string:
+ err := jsonStringToObject(v, &m)
+ return m, err
+ default:
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]bool", i, i)
+ }
+}
+
+// ToStringMapE casts an interface to a map[string]interface{} type.
+func ToStringMapE(i interface{}) (map[string]interface{}, error) {
+ m := map[string]interface{}{}
+
+ switch v := i.(type) {
+ case map[interface{}]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = val
+ }
+ return m, nil
+ case map[string]interface{}:
+ return v, nil
+ case string:
+ err := jsonStringToObject(v, &m)
+ return m, err
+ default:
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]interface{}", i, i)
+ }
+}
+
+// ToStringMapIntE casts an interface to a map[string]int{} type.
+func ToStringMapIntE(i interface{}) (map[string]int, error) {
+ m := map[string]int{}
+ if i == nil {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i)
+ }
+
+ switch v := i.(type) {
+ case map[interface{}]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToInt(val)
+ }
+ return m, nil
+ case map[string]interface{}:
+ for k, val := range v {
+ m[k] = ToInt(val)
+ }
+ return m, nil
+ case map[string]int:
+ return v, nil
+ case string:
+ err := jsonStringToObject(v, &m)
+ return m, err
+ }
+
+ if reflect.TypeOf(i).Kind() != reflect.Map {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i)
+ }
+
+ mVal := reflect.ValueOf(m)
+ v := reflect.ValueOf(i)
+ for _, keyVal := range v.MapKeys() {
+ val, err := ToIntE(v.MapIndex(keyVal).Interface())
+ if err != nil {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i)
+ }
+ mVal.SetMapIndex(keyVal, reflect.ValueOf(val))
+ }
+ return m, nil
+}
+
+// ToStringMapInt64E casts an interface to a map[string]int64{} type.
+func ToStringMapInt64E(i interface{}) (map[string]int64, error) {
+ m := map[string]int64{}
+ if i == nil {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i)
+ }
+
+ switch v := i.(type) {
+ case map[interface{}]interface{}:
+ for k, val := range v {
+ m[ToString(k)] = ToInt64(val)
+ }
+ return m, nil
+ case map[string]interface{}:
+ for k, val := range v {
+ m[k] = ToInt64(val)
+ }
+ return m, nil
+ case map[string]int64:
+ return v, nil
+ case string:
+ err := jsonStringToObject(v, &m)
+ return m, err
+ }
+
+ if reflect.TypeOf(i).Kind() != reflect.Map {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i)
+ }
+ mVal := reflect.ValueOf(m)
+ v := reflect.ValueOf(i)
+ for _, keyVal := range v.MapKeys() {
+ val, err := ToInt64E(v.MapIndex(keyVal).Interface())
+ if err != nil {
+ return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i)
+ }
+ mVal.SetMapIndex(keyVal, reflect.ValueOf(val))
+ }
+ return m, nil
+}
+
+// ToSliceE casts an interface to a []interface{} type.
+func ToSliceE(i interface{}) ([]interface{}, error) {
+ var s []interface{}
+
+ switch v := i.(type) {
+ case []interface{}:
+ return append(s, v...), nil
+ case []map[string]interface{}:
+ for _, u := range v {
+ s = append(s, u)
+ }
+ return s, nil
+ default:
+ return s, fmt.Errorf("unable to cast %#v of type %T to []interface{}", i, i)
+ }
+}
+
+// ToBoolSliceE casts an interface to a []bool type.
+func ToBoolSliceE(i interface{}) ([]bool, error) {
+ if i == nil {
+ return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i)
+ }
+
+ switch v := i.(type) {
+ case []bool:
+ return v, nil
+ }
+
+ kind := reflect.TypeOf(i).Kind()
+ switch kind {
+ case reflect.Slice, reflect.Array:
+ s := reflect.ValueOf(i)
+ a := make([]bool, s.Len())
+ for j := 0; j < s.Len(); j++ {
+ val, err := ToBoolE(s.Index(j).Interface())
+ if err != nil {
+ return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i)
+ }
+ a[j] = val
+ }
+ return a, nil
+ default:
+ return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i)
+ }
+}
+
+// ToStringSliceE casts an interface to a []string type.
+func ToStringSliceE(i interface{}) ([]string, error) {
+ var a []string
+
+ switch v := i.(type) {
+ case []interface{}:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case []string:
+ return v, nil
+ case []int8:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case []int:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case []int32:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case []int64:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case []float32:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case []float64:
+ for _, u := range v {
+ a = append(a, ToString(u))
+ }
+ return a, nil
+ case string:
+ return strings.Fields(v), nil
+ case []error:
+ for _, err := range i.([]error) {
+ a = append(a, err.Error())
+ }
+ return a, nil
+ case interface{}:
+ str, err := ToStringE(v)
+ if err != nil {
+ return a, fmt.Errorf("unable to cast %#v of type %T to []string", i, i)
+ }
+ return []string{str}, nil
+ default:
+ return a, fmt.Errorf("unable to cast %#v of type %T to []string", i, i)
+ }
+}
+
+// ToIntSliceE casts an interface to a []int type.
+func ToIntSliceE(i interface{}) ([]int, error) {
+ if i == nil {
+ return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i)
+ }
+
+ switch v := i.(type) {
+ case []int:
+ return v, nil
+ }
+
+ kind := reflect.TypeOf(i).Kind()
+ switch kind {
+ case reflect.Slice, reflect.Array:
+ s := reflect.ValueOf(i)
+ a := make([]int, s.Len())
+ for j := 0; j < s.Len(); j++ {
+ val, err := ToIntE(s.Index(j).Interface())
+ if err != nil {
+ return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i)
+ }
+ a[j] = val
+ }
+ return a, nil
+ default:
+ return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i)
+ }
+}
+
+// ToDurationSliceE casts an interface to a []time.Duration type.
+func ToDurationSliceE(i interface{}) ([]time.Duration, error) {
+ if i == nil {
+ return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i)
+ }
+
+ switch v := i.(type) {
+ case []time.Duration:
+ return v, nil
+ }
+
+ kind := reflect.TypeOf(i).Kind()
+ switch kind {
+ case reflect.Slice, reflect.Array:
+ s := reflect.ValueOf(i)
+ a := make([]time.Duration, s.Len())
+ for j := 0; j < s.Len(); j++ {
+ val, err := ToDurationE(s.Index(j).Interface())
+ if err != nil {
+ return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i)
+ }
+ a[j] = val
+ }
+ return a, nil
+ default:
+ return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i)
+ }
+}
+
+// StringToDate attempts to parse a string into a time.Time type using a
+// predefined list of formats. If no suitable format is found, an error is
+// returned.
+func StringToDate(s string) (time.Time, error) {
+ return parseDateWith(s, time.UTC, timeFormats)
+}
+
+// StringToDateInDefaultLocation casts an empty interface to a time.Time,
+// interpreting inputs without a timezone to be in the given location,
+// or the local timezone if nil.
+func StringToDateInDefaultLocation(s string, location *time.Location) (time.Time, error) {
+ return parseDateWith(s, location, timeFormats)
+}
+
+type timeFormatType int
+
+const (
+ timeFormatNoTimezone timeFormatType = iota
+ timeFormatNamedTimezone
+ timeFormatNumericTimezone
+ timeFormatNumericAndNamedTimezone
+ timeFormatTimeOnly
+)
+
+type timeFormat struct {
+ format string
+ typ timeFormatType
+}
+
+func (f timeFormat) hasTimezone() bool {
+ // We don't include the formats with only named timezones, see
+ // https://github.com/golang/go/issues/19694#issuecomment-289103522
+ return f.typ >= timeFormatNumericTimezone && f.typ <= timeFormatNumericAndNamedTimezone
+}
+
+var timeFormats = []timeFormat{
+ // Keep common formats at the top.
+ {"2006-01-02", timeFormatNoTimezone},
+ {time.RFC3339, timeFormatNumericTimezone},
+ {"2006-01-02T15:04:05", timeFormatNoTimezone}, // iso8601 without timezone
+ {time.RFC1123Z, timeFormatNumericTimezone},
+ {time.RFC1123, timeFormatNamedTimezone},
+ {time.RFC822Z, timeFormatNumericTimezone},
+ {time.RFC822, timeFormatNamedTimezone},
+ {time.RFC850, timeFormatNamedTimezone},
+ {"2006-01-02 15:04:05.999999999 -0700 MST", timeFormatNumericAndNamedTimezone}, // Time.String()
+ {"2006-01-02T15:04:05-0700", timeFormatNumericTimezone}, // RFC3339 without timezone hh:mm colon
+ {"2006-01-02 15:04:05Z0700", timeFormatNumericTimezone}, // RFC3339 without T or timezone hh:mm colon
+ {"2006-01-02 15:04:05", timeFormatNoTimezone},
+ {time.ANSIC, timeFormatNoTimezone},
+ {time.UnixDate, timeFormatNamedTimezone},
+ {time.RubyDate, timeFormatNumericTimezone},
+ {"2006-01-02 15:04:05Z07:00", timeFormatNumericTimezone},
+ {"02 Jan 2006", timeFormatNoTimezone},
+ {"2006-01-02 15:04:05 -07:00", timeFormatNumericTimezone},
+ {"2006-01-02 15:04:05 -0700", timeFormatNumericTimezone},
+ {time.Kitchen, timeFormatTimeOnly},
+ {time.Stamp, timeFormatTimeOnly},
+ {time.StampMilli, timeFormatTimeOnly},
+ {time.StampMicro, timeFormatTimeOnly},
+ {time.StampNano, timeFormatTimeOnly},
+}
+
+func parseDateWith(s string, location *time.Location, formats []timeFormat) (d time.Time, e error) {
+ for _, format := range formats {
+ if d, e = time.Parse(format.format, s); e == nil {
+
+ // Some time formats have a zone name, but no offset, so it gets
+ // put in that zone name (not the default one passed in to us), but
+ // without that zone's offset. So set the location manually.
+ if format.typ <= timeFormatNamedTimezone {
+ if location == nil {
+ location = time.Local
+ }
+ year, month, day := d.Date()
+ hour, min, sec := d.Clock()
+ d = time.Date(year, month, day, hour, min, sec, d.Nanosecond(), location)
+ }
+
+ return
+ }
+ }
+ return d, fmt.Errorf("unable to parse date: %s", s)
+}
+
+// jsonStringToObject attempts to unmarshall a string as JSON into
+// the object passed as pointer.
+func jsonStringToObject(s string, v interface{}) error {
+ data := []byte(s)
+ return json.Unmarshal(data, v)
+}
+
+// toInt returns the int value of v if v or v's underlying type
+// is an int.
+// Note that this will return false for int64 etc. types.
+func toInt(v interface{}) (int, bool) {
+ switch v := v.(type) {
+ case int:
+ return v, true
+ case time.Weekday:
+ return int(v), true
+ case time.Month:
+ return int(v), true
+ default:
+ return 0, false
+ }
+}
+
+func trimZeroDecimal(s string) string {
+ var foundZero bool
+ for i := len(s); i > 0; i-- {
+ switch s[i-1] {
+ case '.':
+ if foundZero {
+ return s[:i-1]
+ }
+ case '0':
+ foundZero = true
+ default:
+ return s
+ }
+ }
+ return s
+}
@@ -0,0 +1,27 @@
+// Code generated by "stringer -type timeFormatType"; DO NOT EDIT.
+
+package cast
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[timeFormatNoTimezone-0]
+ _ = x[timeFormatNamedTimezone-1]
+ _ = x[timeFormatNumericTimezone-2]
+ _ = x[timeFormatNumericAndNamedTimezone-3]
+ _ = x[timeFormatTimeOnly-4]
+}
+
+const _timeFormatType_name = "timeFormatNoTimezonetimeFormatNamedTimezonetimeFormatNumericTimezonetimeFormatNumericAndNamedTimezonetimeFormatTimeOnly"
+
+var _timeFormatType_index = [...]uint8{0, 20, 43, 68, 101, 119}
+
+func (i timeFormatType) String() string {
+ if i < 0 || i >= timeFormatType(len(_timeFormatType_index)-1) {
+ return "timeFormatType(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _timeFormatType_name[_timeFormatType_index[i]:_timeFormatType_index[i+1]]
+}
@@ -261,7 +261,7 @@ github.com/charmbracelet/bubbles/v2/table
github.com/charmbracelet/bubbles/v2/textarea
github.com/charmbracelet/bubbles/v2/textinput
github.com/charmbracelet/bubbles/v2/viewport
-# github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.1 => github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250702164605-a991b583c0e7
+# github.com/charmbracelet/bubbletea/v2 v2.0.0-beta.1 => github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250703182356-a42fb608faaf
## explicit; go 1.24.3
github.com/charmbracelet/bubbletea/v2
# github.com/charmbracelet/colorprofile v0.3.1
@@ -276,12 +276,12 @@ github.com/charmbracelet/glamour/v2
github.com/charmbracelet/glamour/v2/ansi
github.com/charmbracelet/glamour/v2/internal/autolink
github.com/charmbracelet/glamour/v2/styles
-# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1.0.20250513162854-28902d027c40 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250702164623-bd5b9da8d487
+# github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.2.0.20250703152125-8e1c474f8a71 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250703152138-ff346e83e819
## explicit; go 1.24.2
github.com/charmbracelet/lipgloss/v2
github.com/charmbracelet/lipgloss/v2/table
github.com/charmbracelet/lipgloss/v2/tree
-# github.com/charmbracelet/ultraviolet v0.0.0-20250702190342-c2f25359be42
+# github.com/charmbracelet/ultraviolet v0.0.0-20250707134318-0fdaa64b8c5e
## explicit; go 1.24.0
github.com/charmbracelet/ultraviolet
# github.com/charmbracelet/x/ansi v0.9.3
@@ -410,10 +410,13 @@ github.com/lucasb-eyer/go-colorful
## explicit; go 1.12
github.com/mailru/easyjson/buffer
github.com/mailru/easyjson/jwriter
-# github.com/mark3labs/mcp-go v0.17.0
+# github.com/mark3labs/mcp-go v0.32.0
## explicit; go 1.23
github.com/mark3labs/mcp-go/client
+github.com/mark3labs/mcp-go/client/transport
github.com/mark3labs/mcp-go/mcp
+github.com/mark3labs/mcp-go/server
+github.com/mark3labs/mcp-go/util
# github.com/mattn/go-isatty v0.0.20
## explicit; go 1.15
github.com/mattn/go-isatty
@@ -461,7 +464,7 @@ github.com/ncruces/julianday
# github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
## explicit
github.com/nfnt/resize
-# github.com/openai/openai-go v0.1.0-beta.2
+# github.com/openai/openai-go v1.8.2
## explicit; go 1.21
github.com/openai/openai-go
github.com/openai/openai-go/azure
@@ -473,16 +476,17 @@ github.com/openai/openai-go/internal/apiquery
github.com/openai/openai-go/internal/encoding/json
github.com/openai/openai-go/internal/encoding/json/sentinel
github.com/openai/openai-go/internal/encoding/json/shims
-github.com/openai/openai-go/internal/param
+github.com/openai/openai-go/internal/paramutil
github.com/openai/openai-go/internal/requestconfig
github.com/openai/openai-go/option
github.com/openai/openai-go/packages/pagination
github.com/openai/openai-go/packages/param
-github.com/openai/openai-go/packages/resp
+github.com/openai/openai-go/packages/respjson
github.com/openai/openai-go/packages/ssestream
github.com/openai/openai-go/responses
github.com/openai/openai-go/shared
github.com/openai/openai-go/shared/constant
+github.com/openai/openai-go/webhooks
# github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
## explicit; go 1.14
github.com/pkg/browser
@@ -513,6 +517,9 @@ github.com/sahilm/fuzzy
# github.com/sethvargo/go-retry v0.3.0
## explicit; go 1.21
github.com/sethvargo/go-retry
+# github.com/spf13/cast v1.7.1
+## explicit; go 1.19
+github.com/spf13/cast
# github.com/spf13/cobra v1.9.1
## explicit; go 1.15
github.com/spf13/cobra
@@ -817,5 +824,5 @@ mvdan.cc/sh/v3/fileutil
mvdan.cc/sh/v3/interp
mvdan.cc/sh/v3/pattern
mvdan.cc/sh/v3/syntax
-# github.com/charmbracelet/bubbletea/v2 => github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250702164605-a991b583c0e7
-# github.com/charmbracelet/lipgloss/v2 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250702164623-bd5b9da8d487
+# github.com/charmbracelet/bubbletea/v2 => github.com/charmbracelet/bubbletea-internal/v2 v2.0.0-20250703182356-a42fb608faaf
+# github.com/charmbracelet/lipgloss/v2 => github.com/charmbracelet/lipgloss-internal/v2 v2.0.0-20250703152138-ff346e83e819