Vivado bit exact softmax #1225

Open: wants to merge 2 commits into main
6 changes: 5 additions & 1 deletion hls4ml/backends/fpga/passes/fix_softmax_table_size.py
@@ -6,7 +6,11 @@

class FixSoftmaxTableSize(OptimizerPass):
def match(self, node):
return isinstance(node, Softmax)
if not isinstance(node, Softmax):
return False
if 'inv_table_size' in node.attributes:
return False  # a handler that generates inv_table_size has already set it properly
return True

def transform(self, model, node: Layer):
inp_layer = node.get_input_node() # type: ignore
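The guard added to match above means the pass only fires for Softmax nodes that do not yet carry an inv_table_size attribute, so it no longer overrides a table size that the new bit-exact softmax handler has already chosen. A minimal sketch of that decision, using a plain dict as a hypothetical stand-in for node.attributes:

# Hypothetical stand-in for FixSoftmaxTableSize.match: a plain dict plays the role
# of node.attributes and the isinstance(node, Softmax) check becomes a flag.
def table_size_pass_fires(is_softmax: bool, attributes: dict) -> bool:
    if not is_softmax:
        return False
    # a handler that already generated inv_table_size is trusted to have sized it
    return 'inv_table_size' not in attributes

assert table_size_pass_fires(True, {}) is True
assert table_size_pass_fires(True, {'inv_table_size': 1024}) is False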
63 changes: 61 additions & 2 deletions hls4ml/backends/vivado/passes/core_templates.py
@@ -1,3 +1,5 @@
from math import ceil, log2

from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
@@ -152,13 +154,21 @@ def format(self, node):

softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
static const unsigned n_in = {n_in};
static const unsigned table_size = {table_size};
static const unsigned n_outer = {n_outer};
static const unsigned n_inner = {n_inner};
static const unsigned parallelization_factor = {parallelization_factor};
static const unsigned exp_table_size = {exp_table_size};
static const unsigned inv_table_size = {inv_table_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
static const unsigned axis = {axis};
static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
static constexpr float exp_scale = {exp_scale};
typedef {exp_table_t.name} exp_table_t;
typedef {inv_table_t.name} inv_table_t;
typedef {accum_t_str} accum_t;
typedef {inv_inp_t.name} inv_inp_t;
typedef {inp_norm_t_str} inp_norm_t;
}};\n"""

activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
@@ -210,10 +220,59 @@ def __init__(self):
super(ActivationConfigTemplate, self).__init__(Softmax) # Skip ActivationConfigTemplate's __init__
self.template = softmax_config_template

def format(self, node):
params = self._default_config_params(node)
params['type'] = node.get_attr('activation')
params.setdefault('exp_table_size', params['table_size'])
params.setdefault('inv_table_size', params['table_size'])
params.setdefault('n_inner', 1)
params.setdefault('n_outer', 1)
params.setdefault('exp_scale', 1.0)
params.setdefault('parallelization_factor', -1)
if params['accum_t'].name == 'model_default_t': # type: ignore
scale = ceil(log2(node.attributes['n_in']))
exp_table_t = node.attributes['exp_table_t'].precision
signed, width, integers = exp_table_t.signed, exp_table_t.width, exp_table_t.integer
params['accum_t_str'] = f'ap_{"" if signed else "u"}fixed<{width + scale}, {integers + scale}>'
else:
params['accum_t_str'] = params['accum_t'].name # type: ignore
if params['inv_inp_t'].name == 'model_default_t': # type: ignore
params['inv_inp_t'] = params['exp_table_t']

if 'inp_norm_t' not in params:
# Only used in stable (max-normalized) implementation
input_t = node.get_input_variable().type.precision
width, iwidth, signed = input_t.width, input_t.integer, input_t.signed # noqa: F841
width, iwidth = width - signed, iwidth - signed
if signed:
# Fix table size if too large
exp_table_size = params['exp_table_size']
params['exp_table_size'] = str(min(int(exp_table_size), 2**width))
params['inp_norm_t_str'] = f'ap_ufixed<{width}, {iwidth}>'
else:
params['inp_norm_t_str'] = params['inp_norm_t'].name # type: ignore

return self.template.format(**params)


class SoftmaxFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Softmax, include_header=activ_include_list)
self.template = activ_function_template

def format(self, node):
params = self._default_function_params(node)
use_multidim = node.get_attr('n_inner', 1) > 1 or node.get_attr('n_outer', 1) > 1
use_multidim = use_multidim and node.model.config.get_config_value('IOType') == 'io_parallel'
params['activation'] = 'softmax' if not use_multidim else 'softmax_multidim'
params['config'] = f'softmax_config{node.index}'

return self.template.format(**params)


class ActivationFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list)
super().__init__((Activation, HardActivation), include_header=activ_include_list)
self.template = activ_function_template

def format(self, node):
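A note on the accum_t fallback in SoftmaxConfigTemplate.format above: when the accumulator type is left at model_default_t, the generated config widens the exp table precision by ceil(log2(n_in)) bits, presumably so the sum of up to n_in exponential terms cannot overflow. A small sketch of that arithmetic with made-up example values (n_in = 10 and an exp_table_t of ap_fixed<18, 8>, neither of which comes from this PR):

from math import ceil, log2

# Mirrors the accum_t widening in SoftmaxConfigTemplate.format above.
def widened_accum_type(n_in: int, width: int, integer: int, signed: bool = True) -> str:
    scale = ceil(log2(n_in))  # extra bits needed to accumulate n_in table entries
    return f'ap_{"" if signed else "u"}fixed<{width + scale}, {integer + scale}>'

print(widened_accum_type(10, 18, 8))  # -> ap_fixed<22, 12>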
8 changes: 0 additions & 8 deletions hls4ml/backends/vivado/vivado_backend.py
@@ -26,7 +26,6 @@
SeparableConv1D,
SeparableConv2D,
SimpleRNN,
Softmax,
)
from hls4ml.model.optimizer import get_backend_passes, layer_optimizer
from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType
@@ -551,13 +550,6 @@ def init_pooling1d(self, layer):
def init_pooling2d(self, layer):
layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())

@layer_optimizer(Softmax)
def init_softmax(self, layer):
if layer.model.config.get_config_value('IOType') == 'io_parallel':
assert (
len(layer.get_input_variable().shape) == 1
), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.'

@layer_optimizer(Embedding)
def init_embed(self, layer):
if layer.attributes['n_in'] is None:
224 changes: 224 additions & 0 deletions hls4ml/converters/keras_v3/core.py
@@ -0,0 +1,224 @@
import inspect
import typing
from math import prod
from typing import Any, Sequence

import numpy as np

from ._base import KerasV3LayerHandler, register

if typing.TYPE_CHECKING:
import keras
from keras.api import KerasTensor
from keras.src.layers.merging.base_merge import Merge


@register
class KV3DenseHandler(KerasV3LayerHandler):
handles = ('keras.src.layers.core.dense.Dense',)

def handle(
self,
layer: 'keras.layers.Dense',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):

kernel = self.load_weight(layer, 'kernel')
bias = self.load_weight(layer, 'bias') if layer.use_bias else None
n_in, n_out = kernel.shape

config = {
'data_format': 'channels_last',
'weight_data': kernel,
'bias_data': bias,
'n_out': n_out,
'n_in': n_in,
}
return config


@register
class KV3InputHandler(KerasV3LayerHandler):
handles = ('keras.src.layers.core.input_layer.InputLayer',)

def handle(
self,
layer: 'keras.layers.InputLayer',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
config = {'input_shape': list(layer._batch_shape[1:])}
return config


@register
class KV3MergeHandler(KerasV3LayerHandler):
handles = (
'keras.src.layers.merging.add.Add',
'keras.src.layers.merging.multiply.Multiply',
'keras.src.layers.merging.average.Average',
'keras.src.layers.merging.maximum.Maximum',
'keras.src.layers.merging.minimum.Minimum',
'keras.src.layers.merging.concatenate.Concatenate',
'keras.src.layers.merging.subtract.Subtract',
'keras.src.layers.merging.dot.Dot',
)

def handle(
self,
layer: 'Merge',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
cls_name: str | None = None,
):
assert len(out_tensors) == 1, f"Merge layer {layer.name} has more than one output"
output_shape = list(out_tensors[0].shape[1:])

cls_name = cls_name or layer.__class__.__name__
config: dict[str, Any] = {
'output_shape': output_shape,
'op': cls_name.lower(),
}

match cls_name.lower():
case 'concatenate':
rank = len(output_shape)
class_name = f'Concatenate{rank}d'
config['axis'] = layer.axis
case 'dot':
class_name = f'Dot{len(output_shape)}d'
rank = len(output_shape)
assert rank == 1, f"Dot product only supported for 1D tensors, got {rank}D on layer {layer.name}"
case _:
class_name = 'Merge'

config['class_name'] = class_name
return config


@register
class KV3ActivationHandler(KerasV3LayerHandler):
handles = ('keras.src.layers.activations.activation.Activation',)

def handle(
self,
layer: 'keras.layers.Activation',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
import keras

config = {}
config.update(self.default_config)

activation = getattr(layer, 'activation', keras.activations.linear)
match activation:
case keras.activations.softmax:
class_name = 'Softmax'
config['axis'] = -1
case keras.activations.hard_sigmoid:
class_name = 'HardActivation'
case keras.activations.leaky_relu:
class_name = 'LeakyReLU'
signature = inspect.signature(keras.activations.leaky_relu)
config['activ_param'] = signature.parameters['negative_slope'].default
case keras.activations.elu:
class_name = 'ELU'
signature = inspect.signature(keras.activations.elu)
config['activ_param'] = signature.parameters['alpha'].default
case _:
class_name = 'Activation'

config['activation'] = activation.__name__
config['class_name'] = class_name
config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore
return (config,)


@register
class KV3ReLUHandler(KerasV3LayerHandler):
handles = (
'keras.src.layers.activations.leaky_relu.LeakyReLU',
'keras.src.layers.activations.prelu.PReLU',
'keras.src.layers.activations.relu.ReLU',
)

def handle(
self,
layer: 'keras.layers.ReLU',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
config = {}
config.update(self.default_config)

if layer.__class__.__name__ == 'ReLU':
config['class_name'] = 'Activation'
config['activation'] = 'relu'
return config

if layer.__class__.__name__ == 'PReLU':
config['class_name'] = 'PReLU'
config['param_data'] = np.array(layer.alpha)
config['activation'] = 'prelu'
else:
config['class_name'] = 'LeakyReLU'
config['activ_param'] = float(layer.negative_slope)
config['activation'] = 'leaky_relu'

return (config,)


@register
class KV3SoftmaxHandler(KerasV3LayerHandler):
handles = ('keras.src.layers.activations.softmax.Softmax',)

def handle(
self,
layer: 'keras.layers.Softmax',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
ax = layer.axis
ax = ax if ax >= 0 else len(in_tensors[0].shape) + ax
# io_stream asserts axis == -1; convert a positive last-axis value to -1
n_outer: int = prod(in_tensors[0].shape[1:ax]) # type: ignore
n_inner: int = prod(in_tensors[0].shape[ax + 1 :]) # type: ignore
ax = -1 if ax == len(in_tensors[0].shape) - 1 else ax
config = {}
config.update(self.default_config)
if len(in_tensors) == 2:
raise NotImplementedError("Masked softmax not supported yet")
config['class_name'] = 'MaskedSoftmax'
elif len(in_tensors) == 1:
config['class_name'] = 'Softmax'
else:
raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}")
config['axis'] = ax
config['activation'] = 'softmax'
config['n_outer'] = n_outer
config['n_inner'] = n_inner

return (config,)
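For a softmax taken over an inner axis, the handler above folds the dimensions around the softmax axis into n_outer and n_inner. A worked example with a hypothetical (batch, 4, 10) input and axis=1 (the shape and axis are illustrative, not from this PR):

from math import prod

shape = (None, 4, 10)  # batch dimension first; softmax over the length-4 axis
ax = 1                 # already non-negative, so no wrap-around needed
n_outer = prod(shape[1:ax])       # -> 1, no dimensions before the softmax axis
n_inner = prod(shape[ax + 1:])    # -> 10, dimensions after the softmax axis
ax = -1 if ax == len(shape) - 1 else ax  # stays 1: not the last axis
print(n_outer, n_inner, ax)  # 1 10 1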


@register
class KV3HardActivationHandler(KerasV3LayerHandler):
handles = ('keras.src.layers.activations.elu.ELU',)

def handle(
self,
layer: 'keras.layers.ELU',
in_tensors: Sequence['KerasTensor'],
out_tensors: Sequence['KerasTensor'],
):
config = {}
config.update(self.default_config)

config['class_name'] = 'ELU'
config['activ_param'] = float(layer.alpha)
config['activation'] = 'elu'
config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore

return (config,)
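As a usage sketch, these handlers are exercised when a Keras 3 model containing the corresponding layers is converted. The snippet below is a hypothetical minimal flow; the config and backend options shown are common hls4ml settings and are not taken from this PR:

import keras
import hls4ml

# Toy model: Dense followed by an explicit Softmax layer, so both
# KV3DenseHandler and KV3SoftmaxHandler are invoked during conversion.
inp = keras.Input(shape=(10,))
out = keras.layers.Softmax(axis=-1)(keras.layers.Dense(10)(inp))
model = keras.Model(inp, out)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, io_type='io_parallel', backend='Vivado'
)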