Overview
Package gather provides the Gather layer for the Zerfoo ML framework.
Index
- func BuildGather[T tensor.Numeric](engine compute.Engine[T], _ numeric.Arithmetic[T], name string, ...) (graph.Node[T], error)
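- type Gather
- func NewWithIndices[T tensor.Numeric](engine compute.Engine[T], indices *tensor.TensorNumeric[int]) *Gather[T]
- func NewWithWeights[T tensor.Numeric](engine compute.Engine[T], weights *tensor.TensorNumeric[T]) *Gather[T]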
- func (g *Gather[T]) Attributes() map[string]interface{}
- func (g *Gather[T]) Backward(ctx context.Context, mode types.BackwardMode, ...) ([]*tensor.TensorNumeric[T], error)
- func (g *Gather[T]) EmbeddedFrozen() []*tensor.TensorNumeric[T]
- func (g *Gather[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
- func (g *Gather[T]) HasEmbeddedWeights() bool
- func (g *Gather[T]) OpType() string
- func (g *Gather[T]) OutputShape() []int
- func (g *Gather[T]) Parameters() []*graph.Parameter[T]
Functions
func BuildGather
func BuildGather[T tensor.Numeric](
	engine compute.Engine[T],
	_ numeric.Arithmetic[T],
	name string,
	params map[string]*graph.Parameter[T],
	attrs map[string]interface{},
) (graph.Node[T], error)
BuildGather constructs a Gather layer. For embedding-style nodes whose name maps to a known weight parameter in params, the weights are embedded in the layer. For "gather from shape" nodes whose indices are constant, the indices are embedded in the layer. All other Gather nodes behave as a general ONNX Gather (indexing along axis 0).
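The three cases BuildGather distinguishes can be pictured as a small dispatch. This is a hypothetical sketch, not Zerfoo code: gatherMode, chooseMode and the name-to-weight map are illustrative stand-ins, and checking the weight mapping before the constant-indices case is an assumption based only on the order in which the cases are described.

```go
package main

import "fmt"

// gatherMode labels the three cases described for BuildGather.
// The type and its values are hypothetical, for illustration only.
type gatherMode int

const (
	modeEmbeddedWeights gatherMode = iota // weights embedded in the layer
	modeEmbeddedIndices                   // constant indices embedded in the layer
	modeGeneral                           // general ONNX Gather along axis 0
)

func (m gatherMode) String() string {
	switch m {
	case modeEmbeddedWeights:
		return "embedded-weights"
	case modeEmbeddedIndices:
		return "embedded-indices"
	default:
		return "general"
	}
}

// chooseMode is a hypothetical dispatch: a node whose name maps to a known
// weight parameter is checked first, then a node with constant indices;
// everything else falls back to the general case. The check order is an
// assumption, not taken from the Zerfoo source.
func chooseMode(name string, weightFor map[string]string, constIndices bool) gatherMode {
	if _, ok := weightFor[name]; ok {
		return modeEmbeddedWeights
	}
	if constIndices {
		return modeEmbeddedIndices
	}
	return modeGeneral
}

func main() {
	// Hypothetical node-name to weight-name mapping.
	weightFor := map[string]string{"embed_tokens": "embed_tokens.weight"}
	fmt.Println(chooseMode("embed_tokens", weightFor, false)) // embedded-weights
	fmt.Println(chooseMode("shape_gather", weightFor, true))  // embedded-indices
	fmt.Println(chooseMode("gather_7", weightFor, false))     // general
}
```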
Types
type Gather
Gather is a layer that gathers slices from a tensor.
func NewWithIndices (added in v0.2.1)
func NewWithIndices[T tensor.Numeric](engine compute.Engine[T], indices *tensor.TensorNumeric[int]) *Gather[T]
NewWithIndices creates a new Gather layer with embedded constant indices. In Forward, inputs[0] is the data tensor and the indices come from the layer itself.
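A layer with embedded indices reads only the data tensor at forward time. The sketch below shows the "gather from shape" idea on a 1-D shape vector, for example picking the sequence length out of a [batch, seq, hidden] shape. It is a hypothetical illustration: gatherConst and the number constraint (a stand-in for tensor.Numeric) are not Zerfoo APIs, and carrying shape values as T is an assumption.

```go
package main

import "fmt"

// number is a stand-in for tensor.Numeric; the real constraint may differ.
type number interface {
	~int | ~int32 | ~int64 | ~float32 | ~float64
}

// gatherConst picks elements of a 1-D data slice at indices that were fixed
// when the layer was built, as in a Shape(x) -> Gather(constant index) pattern.
func gatherConst[T number](data []T, indices []int) ([]T, error) {
	out := make([]T, 0, len(indices))
	for _, idx := range indices {
		if idx < 0 || idx >= len(data) {
			return nil, fmt.Errorf("index %d out of range [0, %d)", idx, len(data))
		}
		out = append(out, data[idx])
	}
	return out, nil
}

func main() {
	// Shape of a hypothetical [batch, seq, hidden] activation, carried as T values.
	shape := []float32{2, 16, 64}
	seqLen, err := gatherConst(shape, []int{1})
	if err != nil {
		panic(err)
	}
	fmt.Println(seqLen) // [16]
}
```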
func NewWithWeights
func NewWithWeights[T tensor.Numeric](engine compute.Engine[T], weights *tensor.TensorNumeric[T]) *Gather[T]
NewWithWeights creates a new Gather layer with embedded weights.
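With embedded weights, the layer behaves like an embedding lookup: each input ID selects one row of a [vocab, dim] weight matrix. The sketch below is a hypothetical illustration using plain row-major float32 slices instead of tensor.TensorNumeric. It assumes the IDs reach Forward in the layer's numeric type T and must be converted to integer row indices; how Zerfoo performs that conversion is not documented here.

```go
package main

import "fmt"

// embed looks up rows of a [vocab, dim] weight matrix stored row-major.
// Token IDs arrive as float32 here on the assumption that Forward's input
// is in the layer's numeric type T and must be converted to row indices.
func embed(weights []float32, vocab, dim int, ids []float32) ([]float32, error) {
	if len(weights) != vocab*dim {
		return nil, fmt.Errorf("weights length %d != %d*%d", len(weights), vocab, dim)
	}
	out := make([]float32, 0, len(ids)*dim)
	for _, v := range ids {
		row := int(v)
		// Reject non-integer IDs and IDs outside the vocabulary.
		if float32(row) != v || row < 0 || row >= vocab {
			return nil, fmt.Errorf("invalid token id %v for vocab size %d", v, vocab)
		}
		// Copy the whole embedding row for this token.
		out = append(out, weights[row*dim:(row+1)*dim]...)
	}
	return out, nil
}

func main() {
	// vocab = 3, dim = 2
	weights := []float32{
		0, 0.1, // row 0
		1, 1.1, // row 1
		2, 2.1, // row 2
	}
	out, err := embed(weights, 3, 2, []float32{2, 0})
	if err != nil {
		panic(err)
	}
	fmt.Println(out) // [2 2.1 0 0.1]
}
```

The result is a flattened [len(ids), dim] matrix: one embedding row per input ID.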
func (*Gather[T]) Attributes
func (g *Gather[T]) Attributes() map[string]interface{}
Attributes returns nil for the Gather layer.
func (*Gather[T]) Backward
func (g *Gather[T]) Backward(ctx context.Context, mode types.BackwardMode, outputGradient *tensor.TensorNumeric[T], inputs ...*tensor.TensorNumeric[T]) ([]*tensor.TensorNumeric[T], error)
Backward computes the gradients for the Gather layer.
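This documentation does not say which inputs Backward returns gradients for. As background only, the standard gradient of an axis-0 gather is a scatter-add: every row of the output gradient is added into a zero tensor at the row the forward pass read from, so repeated indices accumulate. The sketch uses a hypothetical gatherBackward function and plain float64 slices; it is not Zerfoo's implementation.

```go
package main

import "fmt"

// gatherBackward sketches the standard gradient of an axis-0 gather:
// each row of outGrad is added into the row of a zero tensor that the
// forward pass read from. Repeated indices accumulate.
// data is [rows, width] row-major; outGrad is [len(indices), width].
func gatherBackward(rows, width int, indices []int, outGrad []float64) []float64 {
	grad := make([]float64, rows*width)
	for i, idx := range indices {
		for j := 0; j < width; j++ {
			grad[idx*width+j] += outGrad[i*width+j]
		}
	}
	return grad
}

func main() {
	// rows = 3, width = 2; row 1 is gathered twice, so its gradients add up.
	g := gatherBackward(3, 2, []int{1, 1, 2}, []float64{1, 2, 3, 4, 5, 6})
	fmt.Println(g) // [0 0 4 6 5 6]
}
```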
func (*Gather[T]) EmbeddedFrozen (added in v0.2.1)
func (g *Gather[T]) EmbeddedFrozen() []*tensor.TensorNumeric[T]
EmbeddedFrozen returns the embedded frozen tensors (the weights) that should be registered as frozen slots during compilation. It returns nil if no weights are embedded. EmbeddedFrozen implements graph.EmbeddedFrozenProvider.
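EmbeddedFrozen follows the common Go pattern of an optional interface that a caller discovers with a type assertion. The sketch below is hypothetical: frozenProvider, gatherLike and collectFrozen stand in for graph.EmbeddedFrozenProvider, Gather and the compilation step, whose real definitions are not shown here, and tensors are simplified to []float32.

```go
package main

import "fmt"

// frozenProvider is a hypothetical stand-in for graph.EmbeddedFrozenProvider;
// its method set is assumed here, not confirmed.
type frozenProvider interface {
	EmbeddedFrozen() [][]float32
}

// gatherLike is a hypothetical node that may carry embedded weights.
type gatherLike struct{ weights []float32 }

// EmbeddedFrozen returns nil when no weights are embedded.
func (g *gatherLike) EmbeddedFrozen() [][]float32 {
	if g.weights == nil {
		return nil
	}
	return [][]float32{g.weights}
}

// Compile-time check that *gatherLike satisfies the interface.
var _ frozenProvider = (*gatherLike)(nil)

// collectFrozen sketches how a compiler pass might gather frozen slots
// from every node that opts in via the interface.
func collectFrozen(nodes []any) [][]float32 {
	var slots [][]float32
	for _, n := range nodes {
		if p, ok := n.(frozenProvider); ok {
			slots = append(slots, p.EmbeddedFrozen()...)
		}
	}
	return slots
}

func main() {
	nodes := []any{&gatherLike{weights: []float32{1, 2}}, &gatherLike{}, "not a provider"}
	fmt.Println(len(collectFrozen(nodes))) // 1
}
```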
func (*Gather[T]) Forward
func (g *Gather[T]) Forward(ctx context.Context, inputs ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error)
Forward computes the gather operation.
func (*Gather[T]) HasEmbeddedWeights
func (g *Gather[T]) HasEmbeddedWeights() bool
HasEmbeddedWeights reports whether this Gather layer has embedded weights.
func (*Gather[T]) OpType
func (g *Gather[T]) OpType() string
OpType returns the operation type name of the Gather layer.
func (*Gather[T]) OutputShape
func (g *Gather[T]) OutputShape() []int
OutputShape returns the output shape of the Gather layer.
func (*Gather[T]) Parameters
func (g *Gather[T]) Parameters() []*graph.Parameter[T]
Parameters returns no trainable parameters for the Gather layer.