diff --git a/hls4ml/backends/fpga/passes/fix_softmax_table_size.py b/hls4ml/backends/fpga/passes/fix_softmax_table_size.py index 4e04626d2e..860aa89597 100644 --- a/hls4ml/backends/fpga/passes/fix_softmax_table_size.py +++ b/hls4ml/backends/fpga/passes/fix_softmax_table_size.py @@ -6,7 +6,11 @@ class FixSoftmaxTableSize(OptimizerPass): def match(self, node): - return isinstance(node, Softmax) + if not isinstance(node, Softmax): + return False + if 'inv_table_size' in node.attributes: + return False # handler generating inv_table_size sets it properly + return True def transform(self, model, node: Layer): inp_layer = node.get_input_node() # type: ignore diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 1393cdfb49..7e18af7fd7 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -1,3 +1,5 @@ +from math import ceil, log2 + from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax @@ -152,13 +154,21 @@ def format(self, node): softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{ static const unsigned n_in = {n_in}; - static const unsigned table_size = {table_size}; + static const unsigned n_outer = {n_outer}; + static const unsigned n_inner = {n_inner}; + static const unsigned parallelization_factor = {parallelization_factor}; + static const unsigned exp_table_size = {exp_table_size}; + static const unsigned inv_table_size = {inv_table_size}; static const unsigned io_type = nnet::{iotype}; static const unsigned reuse_factor = {reuse}; static const unsigned axis = {axis}; static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation}; + static constexpr float exp_scale = {exp_scale}; typedef {exp_table_t.name} exp_table_t; typedef {inv_table_t.name} inv_table_t; + typedef {accum_t_str} accum_t; + typedef {inv_inp_t.name} inv_inp_t; + typedef {inp_norm_t_str} inp_norm_t; }};\n""" activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});' @@ -210,10 +220,59 @@ def __init__(self): super(ActivationConfigTemplate, self).__init__(Softmax) # Skip ActivationConfigTemplate's __init__ self.template = softmax_config_template + def format(self, node): + params = self._default_config_params(node) + params['type'] = node.get_attr('activation') + params.setdefault('exp_table_size', params['table_size']) + params.setdefault('inv_table_size', params['table_size']) + params.setdefault('n_inner', 1) + params.setdefault('n_outer', 1) + params.setdefault('exp_scale', 1.0) + params.setdefault('parallelization_factor', -1) + if params['accum_t'].name == 'model_default_t': # type: ignore + scale = ceil(log2(node.attributes['n_in'])) + exp_table_t = node.attributes['exp_table_t'].precision + signed, width, integers = exp_table_t.signed, exp_table_t.width, exp_table_t.integer + params['accum_t_str'] = f'ap_{"" if signed else "u"}fixed<{width + scale}, {integers + scale}>' + else: + params['accum_t_str'] = params['accum_t'].name # type: ignore + if params['inv_inp_t'].name == 'model_default_t': # type: ignore + params['inv_inp_t'] = params['exp_table_t'] + + if 'inp_norm_t' not in params: + # Only used in stable (max-normalized) implementation + input_t =
node.get_input_variable().type.precision + width, iwidth, signed = input_t.width, input_t.integer, input_t.signed # noqa: F841 + width, iwidth = width - signed, iwidth - signed + if signed: + # Fix table size if too large + exp_table_size = params['inv_table_size'] + params['exp_table_size'] = str(min(int(exp_table_size), 2**width)) + params['inp_norm_t_str'] = f'ap_ufixed<{width}, {iwidth}>' + else: + params['inp_norm_t_str'] = params['inp_norm_t'].name # type: ignore + + return self.template.format(**params) + + +class SoftmaxFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(Softmax, include_header=activ_include_list) + self.template = activ_function_template + + def format(self, node): + params = self._default_function_params(node) + use_multidim = node.get_attr('n_inner', 1) > 1 or node.get_attr('n_outer', 1) > 1 + use_multidim = use_multidim and node.model.config.get_config_value('IOType') == 'io_parallel' + params['activation'] = 'softmax' if not use_multidim else 'softmax_multidim' + params['config'] = f'softmax_config{node.index}' + + return self.template.format(**params) + class ActivationFunctionTemplate(FunctionCallTemplate): def __init__(self): - super().__init__((Activation, HardActivation, Softmax), include_header=activ_include_list) + super().__init__((Activation, HardActivation), include_header=activ_include_list) self.template = activ_function_template def format(self, node): diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 117805dd86..d2ba498a73 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -26,7 +26,6 @@ SeparableConv1D, SeparableConv2D, SimpleRNN, - Softmax, ) from hls4ml.model.optimizer import get_backend_passes, layer_optimizer from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType @@ -551,13 +550,6 @@ def init_pooling1d(self, layer): def init_pooling2d(self, layer): layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower()) - @layer_optimizer(Softmax) - def init_softmax(self, layer): - if layer.model.config.get_config_value('IOType') == 'io_parallel': - assert ( - len(layer.get_input_variable().shape) == 1 - ), 'Softmax with io_parallel strategy cannot be used on multidimensional tensors.' 
- @layer_optimizer(Embedding) def init_embed(self, layer): if layer.attributes['n_in'] is None: diff --git a/hls4ml/converters/keras_v3/core.py b/hls4ml/converters/keras_v3/core.py new file mode 100644 index 0000000000..27dc04d6ab --- /dev/null +++ b/hls4ml/converters/keras_v3/core.py @@ -0,0 +1,224 @@ +import inspect +import typing +from math import prod +from typing import Any, Sequence + +import numpy as np + +from ._base import KerasV3LayerHandler, register + +if typing.TYPE_CHECKING: + import keras + from keras.api import KerasTensor + from keras.src.layers.merging.base_merge import Merge + + +@register +class KV3DenseHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.dense.Dense',) + + def handle( + self, + layer: 'keras.layers.Dense', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + + kernel = self.load_weight(layer, 'kernel') + bias = self.load_weight(layer, 'bias') if layer.use_bias else None + n_in, n_out = kernel.shape + + config = { + 'data_format': 'channels_last', + 'weight_data': kernel, + 'bias_data': bias, + 'n_out': n_out, + 'n_in': n_in, + } + return config + + +@register +class KV3InputHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.core.input_layer.InputLayer',) + + def handle( + self, + layer: 'keras.layers.InputLayer', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {'input_shape': list(layer._batch_shape[1:])} + return config + + +@register +class KV3MergeHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.merging.add.Add', + 'keras.src.layers.merging.multiply.Multiply', + 'keras.src.layers.merging.average.Average', + 'keras.src.layers.merging.maximum.Maximum', + 'keras.src.layers.merging.minimum.Minimum', + 'keras.src.layers.merging.concatenate.Concatenate', + 'keras.src.layers.merging.subtract.Subtract', + 'keras.src.layers.merging.dot.Dot', + ) + + def handle( + self, + layer: 'Merge', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + cls_name: str | None = None, + ): + assert len(out_tensors) == 1, f"Merge layer {layer.name} has more than one output" + output_shape = list(out_tensors[0].shape[1:]) + + cls_name = cls_name or layer.__class__.__name__ + config: dict[str, Any] = { + 'output_shape': output_shape, + 'op': cls_name.lower(), + } + + match cls_name.lower(): + case 'concatenate': + rank = len(output_shape) + class_name = f'Concatenate{rank}d' + config['axis'] = layer.axis + case 'dot': + class_name = f'Dot{len(output_shape)}d' + rank = len(output_shape) + assert rank == 1, f"Dot product only supported for 1D tensors, got {rank}D on layer {layer.name}" + case _: + class_name = 'Merge' + + config['class_name'] = class_name + return config + + +@register +class KV3ActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.activation.Activation',) + + def handle( + self, + layer: 'keras.layers.Activation', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + import keras + + config = {} + config.update(self.default_config) + + activation = getattr(layer, 'activation', keras.activations.linear) + match activation: + case keras.activations.softmax: + class_name = 'Softmax' + config['axis'] = -1 + case keras.activations.hard_sigmoid: + class_name = 'HardActivation' + case keras.activations.leaky_relu: + class_name = 'LeakyReLU' + signature = inspect.signature(keras.activations.leaky_relu) + config['activ_param'] =
signature.parameters['negative_slope'].default + case keras.activations.elu: + class_name = 'ELU' + signature = inspect.signature(keras.activations.elu) + config['activ_param'] = signature.parameters['alpha'].default + case _: + class_name = 'Activation' + + config['activation'] = activation.__name__ + config['class_name'] = class_name + config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore + return (config,) + + +@register +class KV3ReLUHandler(KerasV3LayerHandler): + handles = ( + 'keras.src.layers.activations.leaky_relu.LeakyReLU', + 'keras.src.layers.activations.prelu.PReLU', + 'keras.src.layers.activations.relu.ReLU', + ) + + def handle( + self, + layer: 'keras.layers.ReLU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + if layer.__class__.__name__ == 'ReLU': + config['class_name'] = 'Activation' + config['activation'] = 'relu' + return config + + if layer.__class__.__name__ == 'PReLU': + config['class_name'] = 'PReLU' + config['param_data'] = np.array(layer.alpha) + config['activation'] = 'prelu' + else: + config['class_name'] = 'LeakyReLU' + config['activ_param'] = float(layer.negative_slope) + config['activation'] = 'leaky_relu' + + return (config,) + + +@register +class KV3SoftmaxHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.softmax.Softmax',) + + def handle( + self, + layer: 'keras.layers.Softmax', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + ax = layer.axis + ax = ax if ax >= 0 else len(in_tensors[0].shape) + ax + # io_stream asserts axis=-1, convert to -1 when it is + n_outer: int = prod(in_tensors[0].shape[1:ax]) # type: ignore + n_inner: int = prod(in_tensors[0].shape[ax + 1 :]) # type: ignore + ax = -1 if ax == len(in_tensors[0].shape) - 1 else ax + config = {} + config.update(self.default_config) + if len(in_tensors) == 2: + raise NotImplementedError("Masked softmax not supported yet") + config['class_name'] = 'MaskedSoftmax' + elif len(in_tensors) == 1: + config['class_name'] = 'Softmax' + else: + raise ValueError(f"Too many inputs for softmax layer {layer.name}: expected 1 or 2, got {len(in_tensors)}") + config['axis'] = layer.axis + config['activation'] = 'softmax' + config['n_outer'] = n_outer + config['n_inner'] = n_inner + + return (config,) + + +@register +class KV3HardActivationHandler(KerasV3LayerHandler): + handles = ('keras.src.layers.activations.elu.ELU',) + + def handle( + self, + layer: 'keras.layers.ELU', + in_tensors: Sequence['KerasTensor'], + out_tensors: Sequence['KerasTensor'], + ): + config = {} + config.update(self.default_config) + + config['class_name'] = 'ELU' + config['activ_param'] = float(layer.alpha) + config['activation'] = 'elu' + config['n_in'] = prod(in_tensors[0].shape[1:]) # type: ignore + + return (config,) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 03e3d9ce8a..ef438cb80e 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1,4 +1,5 @@ import typing +from copy import copy import numpy as np @@ -21,6 +22,9 @@ FixedPrecisionType, IntegerPrecisionType, NamedType, + PrecisionType, + RoundingMode, + SaturationMode, TensorVariable, UnspecifiedPrecisionType, WeightVariable, @@ -29,6 +33,9 @@ from hls4ml.utils import attribute_descriptions as descriptions from hls4ml.utils.string_utils import convert_to_snake_case +if typing.TYPE_CHECKING: + from hls4ml.model import ModelGraph + # TODO move this to some utility module class classproperty: @@ 
-80,7 +87,7 @@ def __init__(self, model, name, attributes, inputs, outputs=None): "No model layer should be named 'input' because that is a reserved;" + "layer name in ModelGraph; Please rename the layer in your model" ) - self.model = model + self.model: 'ModelGraph' = model self.name = name self.index = model.next_layer() self.inputs = inputs @@ -145,6 +152,9 @@ def _validate_attributes(self): # Validate existing attributes for attr_name, attr_value in self.attributes.items(): + if isinstance(attr_value, PrecisionType): + attr_value = self._wrap_precision_to_type(f'{self.name}_{attr_name}', attr_value) + self.set_attr(attr_name, attr_value) exp_attr = all_attributes.pop(attr_name, None) if exp_attr is not None: if not exp_attr.validate_value(attr_value): @@ -160,7 +170,7 @@ def _validate_attributes(self): for attr_name, attr in all_attributes.items(): if attr.default is not None: if isinstance(attr, TypeAttribute): - self.set_attr(attr_name, self._wrap_precision_to_type(self.name + '_' + attr_name, attr.default)) + self.set_attr(attr_name, self._wrap_precision_to_type(self.name + '_' + attr_name, copy(attr.default))) else: self.set_attr(attr_name, attr.default) else: @@ -910,7 +920,8 @@ def initialize(self): shape = inp.shape dims = inp.dim_names self.add_output_variable(shape, dims) - self.set_attr('n_in', self.get_input_variable().size()) + if 'n_in' not in self.attributes: + self.set_attr('n_in', self.get_input_variable().size()) class ParametrizedActivation(Activation): @@ -975,6 +986,31 @@ def initialize(self): class Softmax(Activation): + _expected_attributes = [ + Attribute('n_in'), + Attribute('activation', value_type=str), + Attribute('n_outer', value_type=int, default=1), + Attribute('n_inner', value_type=int, default=1), + ChoiceAttribute('implementation', ['latency', 'stable', 'argmax', 'legacy'], default='stable'), + ConfigurableAttribute('skip', value_type=bool, default=False), + TypeAttribute( + 'exp_table', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + TypeAttribute( + 'inv_table', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + TypeAttribute( + 'inv_inp', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + TypeAttribute( + 'accum', + default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT), + ), + ] + def initialize(self): super().initialize() diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h index 4683239d85..29acf0d7ab 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation.h @@ -130,37 +130,40 @@ enum class softmax_implementation { latency = 0, legacy = 1, stable = 2, argmax inline float exp_fcn_float(float input) { return std::exp(input); } -template inline float softmax_real_val_from_idx(unsigned i) { +template inline float softmax_real_val_from_idx(unsigned i) { // Treat the index as the top N bits - static constexpr int N = ceillog2(CONFIG_T::table_size); // number of address bits for table + static constexpr int N = ceillog2(table_size); // number of address bits for table data_T x(0); x(x.width - 1, x.width - N) = i; return (float)x; } -template inline unsigned softmax_idx_from_real_val(data_T x) { +template inline unsigned softmax_idx_from_real_val(data_T x) { // Slice the top N bits to get an 
index into the table - static constexpr int N = ceillog2(CONFIG_T::table_size); // number of address bits for table - ap_uint y = x(x.width - 1, x.width - N); // slice the top N bits of input + static constexpr int N = ceillog2(table_size); // number of address bits for table + ap_uint y = x(x.width - 1, x.width - N); // slice the top N bits of input return (unsigned)y(N - 1, 0); } template -void init_exp_table(typename CONFIG_T::exp_table_t table_out[CONFIG_T::table_size]) { +void init_exp_table(typename CONFIG_T::exp_table_t table_out[CONFIG_T::exp_table_size], bool negative = false) { // The template data_T is the data type used to address the table - for (unsigned i = 0; i < CONFIG_T::table_size; i++) { + for (unsigned i = 0; i < CONFIG_T::exp_table_size; i++) { // Slicing bits for address is going to round towards 0, so take the central value - float x = softmax_real_val_from_idx(i); + float x = softmax_real_val_from_idx(i) * CONFIG_T::exp_scale; + if (negative) { + x = -x; + } typename CONFIG_T::exp_table_t exp_x = exp_fcn_float(x); table_out[i] = exp_x; } } template -void init_invert_table(typename CONFIG_T::inv_table_t table_out[CONFIG_T::table_size]) { +void init_invert_table(typename CONFIG_T::inv_table_t table_out[CONFIG_T::inv_table_size]) { // The template data_T is the data type used to address the table - for (unsigned i = 0; i < CONFIG_T::table_size; i++) { - float x = softmax_real_val_from_idx(i); + for (unsigned i = 0; i < CONFIG_T::inv_table_size; i++) { + float x = softmax_real_val_from_idx(i); typename CONFIG_T::inv_table_t inv_x = 1 / x; table_out[i] = inv_x; } @@ -172,40 +175,39 @@ void softmax_latency(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T init_exp_table(exp_table); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; + typename CONFIG_T::accum_t exp_res[CONFIG_T::n_in]; #pragma HLS array_partition variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll - unsigned x = softmax_idx_from_real_val(data[i]); + unsigned x = softmax_idx_from_real_val(data[i]); exp_res[i] = exp_table[x]; } // Explicitly sum the results with an adder tree. 
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll res[i] = exp_res[i] * inv_exp_sum; @@ -218,19 +220,19 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T - init_exp_table(exp_table); + init_exp_table(exp_table, true); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } @@ -238,31 +240,31 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { Op_max op_max; data_T x_max = reduce>(data, op_max); - // For the diffs, use the same type as the input but force rounding and saturation - ap_fixed d_xi_xmax[CONFIG_T::n_in]; + typename CONFIG_T::inp_norm_t d_xi_xmax[CONFIG_T::n_in]; for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll - d_xi_xmax[i] = data[i] - x_max; + d_xi_xmax[i] = x_max - data[i]; } // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in]; + typename CONFIG_T::accum_t exp_res[CONFIG_T::n_in]; #pragma HLS array_partition variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll - unsigned x = softmax_idx_from_real_val(d_xi_xmax[i]); + unsigned x = softmax_idx_from_real_val(d_xi_xmax[i]); exp_res[i] = exp_table[x]; } // Explicitly sum the results with an adder tree.
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; for (unsigned i = 0; i < CONFIG_T::n_in; i++) { #pragma HLS unroll res[i] = exp_res[i] * inv_exp_sum; @@ -299,16 +301,16 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Initialize the lookup table #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { - init_exp_table_legacy(exp_table); - init_invert_table_legacy(invert_table); + init_exp_table_legacy(exp_table); + init_invert_table_legacy(invert_table); initialized = true; } @@ -330,12 +332,12 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { if (ii == jj) exp_diff_res = 1; else { - data_round = (data_cache[jj] - data_cache[ii]) * CONFIG_T::table_size / 16; - index = data_round + 8 * CONFIG_T::table_size / 16; + data_round = (data_cache[jj] - data_cache[ii]) * CONFIG_T::exp_table_size / 16; + index = data_round + 8 * CONFIG_T::exp_table_size / 16; if (index < 0) index = 0; - if (index > CONFIG_T::table_size - 1) - index = CONFIG_T::table_size - 1; + if (index > CONFIG_T::exp_table_size - 1) + index = CONFIG_T::exp_table_size - 1; exp_diff_res = exp_table[index]; } exp_res[ii] += exp_diff_res; @@ -344,11 +346,11 @@ void softmax_legacy(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { // Second loop to invert for (int ii = 0; ii < CONFIG_T::n_in; ii++) { - int exp_res_index = exp_res[ii] * CONFIG_T::table_size / 64; + int exp_res_index = exp_res[ii] * CONFIG_T::inv_table_size / 64; if (exp_res_index < 0) exp_res_index = 0; - if (exp_res_index > CONFIG_T::table_size - 1) - exp_res_index = CONFIG_T::table_size - 1; + if (exp_res_index > CONFIG_T::inv_table_size - 1) + exp_res_index = CONFIG_T::inv_table_size - 1; // typename CONFIG_T::table_t exp_res_invert = invert_table[exp_res_index]; res[ii] = (res_T)invert_table[exp_res_index]; } @@ -394,6 +396,30 @@ void softmax(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]) { } } +template +void softmax_multidim(data_T data[CONFIG_T::n_outer * CONFIG_T::n_in * CONFIG_T::n_inner], + res_T res[CONFIG_T::n_outer * CONFIG_T::n_in * CONFIG_T::n_inner]) { + #pragma HLS inline + #pragma HLS allocation instances = softmax limit = CONFIG_T::parallelization_factor function + data_T buffer_in[CONFIG_T::n_in]; + res_T buffer_out[CONFIG_T::n_in]; + for (signed i = 0; i < CONFIG_T::n_outer; i++) { + #pragma HLS UNROLL + for (signed k = 0; k < CONFIG_T::n_inner; k++) { + #pragma HLS UNROLL + for (signed j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + buffer_in[j] = data[i * CONFIG_T::n_in * CONFIG_T::n_inner + j * CONFIG_T::n_inner + k]; + } + softmax(buffer_in, buffer_out); + 
for (signed j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + res[i * CONFIG_T::n_in * CONFIG_T::n_inner + j * CONFIG_T::n_inner + k] = buffer_out[j]; + } + } + } +} + // ************************************************* // TanH Activation // ************************************************* diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h index ef687243bf..00c61933f9 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_activation_stream.h @@ -109,19 +109,19 @@ void softmax_latency(hls::stream &data, hls::stream &res) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T init_exp_table(exp_table); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } @@ -129,9 +129,9 @@ void softmax_latency(hls::stream &data, hls::stream &res) { constexpr unsigned ii = data_T::size / multiplier_limit; // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[data_T::size]; + typename CONFIG_T::accum_t exp_res[data_T::size]; #pragma HLS array_partition variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); SoftmaxExpLoop: for (unsigned i = 0; i < CONFIG_T::n_in / data_T::size; i++) { #pragma HLS PIPELINE II=ii @@ -140,18 +140,17 @@ void softmax_latency(hls::stream &data, hls::stream &res) { SoftmaxExpPackLoop: for (unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - unsigned x = softmax_idx_from_real_val(in_pack[j]); + unsigned x = softmax_idx_from_real_val(in_pack[j]); exp_res[j] = exp_table[x]; } // Explicitly sum the results with an adder tree. 
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + invert_table[softmax_idx_from_real_val(exp_sum)]; res_T out_pack; PRAGMA_DATA_PACK(out_pack) @@ -171,19 +170,19 @@ void softmax_stable(hls::stream &data, hls::stream &res) { // Initialize the lookup tables #ifdef __HLS_SYN__ bool initialized = false; - typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #else static bool initialized = false; - static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size]; - static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::table_size]; + static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::exp_table_size]; + static typename CONFIG_T::inv_table_t invert_table[CONFIG_T::inv_table_size]; #endif if (!initialized) { // Note we are exponentiating the inputs, which have type data_T - init_exp_table(exp_table); + init_exp_table(exp_table, true); // Note we are inverting the exponentials, which have type exp_table_t - init_invert_table(invert_table); + init_invert_table(invert_table); initialized = true; } @@ -208,31 +207,29 @@ void softmax_stable(hls::stream &data, hls::stream &res) { typename data_T::value_type x_max = reduce>(data_array, op_max); - // For the diffs, use the same type as the input but force rounding and saturation - ap_fixed d_xi_xmax[data_T::size]; + typename CONFIG_T::inp_norm_t d_xi_xmax[data_T::size]; for (unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - d_xi_xmax[j] = data_array[j] - x_max; + d_xi_xmax[j] = x_max - data_array[j]; } // Calculate all the e^x's - typename CONFIG_T::exp_table_t exp_res[data_T::size]; + typename CONFIG_T::accum_t exp_res[data_T::size]; #pragma HLS ARRAY_PARTITION variable=exp_res complete - typename CONFIG_T::exp_table_t exp_sum(0); + typename CONFIG_T::inv_inp_t exp_sum(0); for (unsigned j = 0; j < data_T::size; j++) { #pragma HLS UNROLL - unsigned x = softmax_idx_from_real_val(d_xi_xmax[j]); + unsigned x = softmax_idx_from_real_val(d_xi_xmax[j]); exp_res[j] = exp_table[x]; } // Explicitly sum the results with an adder tree. 
// Rounding & Saturation mode, which improve accuracy, prevent Vivado from expression balancing - Op_add op_add; - exp_sum = - reduce>(exp_res, op_add); + Op_add op_add; + exp_sum = reduce>(exp_res, op_add); - typename CONFIG_T::inv_table_t inv_exp_sum = - invert_table[softmax_idx_from_real_val(exp_sum)]; + typename CONFIG_T::accum_t inv_exp_sum = + invert_table[softmax_idx_from_real_val(exp_sum)]; res_T out_pack; PRAGMA_DATA_PACK(out_pack) diff --git a/test/pytest/test_conv1d.py b/test/pytest/test_conv1d.py index b58a35417a..2a00fb8d4a 100644 --- a/test/pytest/test_conv1d.py +++ b/test/pytest/test_conv1d.py @@ -2,7 +2,6 @@ import numpy as np import pytest -from sklearn.metrics import accuracy_score from tensorflow.keras.models import model_from_json import hls4ml @@ -13,7 +12,7 @@ @pytest.fixture(scope='module') def data(): - X = np.random.rand(100, 10, 4) + X = np.random.rand(1000, 10, 4) return X @@ -110,6 +109,5 @@ def test_accuracy(data, keras_model, hls_model): y_hls4ml = hls_model.predict(X) # "Accuracy" of hls4ml predictions vs keras - rel_acc = accuracy_score(np.argmax(y_keras, axis=1), np.argmax(y_hls4ml, axis=1)) - print(f'hls4ml accuracy relative to keras: {rel_acc}') - assert rel_acc > 0.98 + mae = np.mean(np.abs(y_keras - y_hls4ml)) + assert mae < 9e-3
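For reference, a NumPy sketch (not part of the patch) of what the new `softmax_multidim` loop computes: the flat input is treated as an `[n_outer][n_in][n_inner]` block and softmax runs over the middle axis, using the same stable `exp(-(max - x))` formulation that `softmax_stable` now implements with its negated exp table. The function name and example shapes below are illustrative.

```python
import numpy as np


def softmax_multidim_ref(data_flat, n_outer, n_in, n_inner):
    """NumPy reference for the [n_outer][n_in][n_inner] indexing used by softmax_multidim."""
    x = np.asarray(data_flat, dtype=np.float64).reshape(n_outer, n_in, n_inner)
    # Stable formulation: subtract the per-slice max, mirroring d_xi_xmax = x_max - data[i]
    d = np.max(x, axis=1, keepdims=True) - x
    e = np.exp(-d)  # the patched exp table stores exp(-x) via init_exp_table(..., true)
    return (e / e.sum(axis=1, keepdims=True)).reshape(-1)


# Example: 2 outer blocks, softmax over 3 entries, 4 inner strides
flat = np.random.rand(2 * 3 * 4)
out = softmax_multidim_ref(flat, n_outer=2, n_in=3, n_inner=4)
assert np.allclose(out.reshape(2, 3, 4).sum(axis=1), 1.0)
```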
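A short sketch of the `accum_t` auto-widening in the new `format` method of the softmax config template: when `accum_t` is left at `model_default_t`, the accumulator string is derived from `exp_table_t` with `ceil(log2(n_in))` extra bits so the adder-tree sum cannot overflow. The helper name below is hypothetical; the default widths mirror the patch's `FixedPrecisionType(18, 8)` exp table.

```python
from math import ceil, log2


def widened_accum_t(n_in, exp_width=18, exp_integer=8, signed=True):
    """Hypothetical helper: rebuild the ap_fixed string used for accum_t when it is model_default_t."""
    growth = ceil(log2(n_in))  # summing n_in table entries grows the value by at most log2(n_in) bits
    prefix = '' if signed else 'u'
    return f'ap_{prefix}fixed<{exp_width + growth}, {exp_integer + growth}>'


print(widened_accum_t(10))  # -> ap_fixed<22, 12> for the default 18-bit / 8-integer exp table
```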
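Finally, a sketch of the `n_outer`/`n_inner` factorization computed in `KV3SoftmaxHandler` for a softmax over an arbitrary axis; the handler sets `n_outer` and `n_inner` from the input shape, and the axis length itself is shown here only for context. The helper name and example shape are illustrative.

```python
from math import prod


def softmax_factors(shape, axis):
    """Split a (batch, ...) shape into the n_outer / n_in / n_inner factors around the softmax axis."""
    ax = axis if axis >= 0 else len(shape) + axis
    n_outer = prod(shape[1:ax])      # dims before the softmax axis, batch excluded
    n_in = shape[ax]                 # length of the softmax axis
    n_inner = prod(shape[ax + 1:])   # dims after the softmax axis
    return n_outer, n_in, n_inner


# e.g. softmax over axis=2 of a (batch, 4, 10, 3) tensor
assert softmax_factors((None, 4, 10, 3), 2) == (4, 10, 3)
```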