24 commits
f88faa0
init: xls backend (not working), implemented xls specific layer infor…
Girjoaba Jul 16, 2025
4496252
test
Girjoaba Jul 16, 2025
d8b2415
feat: loading weights and creating infrastructure added to writer
Girjoaba Jul 17, 2025
fbe2e82
feat: init writer complete
Girjoaba Jul 17, 2025
f9b2863
feat: first end2end working test
Girjoaba Jul 18, 2025
0952f1b
fix: vector input support, change back to current directory
Girjoaba Jul 18, 2025
b32405b
refactoring: predict function call
Girjoaba Jul 24, 2025
6d91666
feat: solo relu activation test pass
Girjoaba Jul 24, 2025
86dd94a
debt cleanup: split dslx templates in multiple files
Girjoaba Jul 26, 2025
0b9ad57
refactoring: simplified writer -> attribute factory written as an opt…
Girjoaba Jul 28, 2025
a78bd1b
feat: softmax xls implementation of table lookup
Girjoaba Jul 31, 2025
b24e581
integrated strategies for the softmax implementation
Girjoaba Jul 31, 2025
039c514
bugfix: softmax latency implementation
Girjoaba Aug 1, 2025
dc8f5a9
cleanup: removed junk file
Girjoaba Aug 1, 2025
248c0f0
feat: stable softmax 1 specific precision working
Girjoaba Aug 1, 2025
4ff3f94
cleanup: removed junk
Girjoaba Aug 1, 2025
9a73968
feat: integrated stable softmax with all layers
Girjoaba Aug 1, 2025
5684d26
feat: softmax stable and argmax working any bit precision combination
Girjoaba Aug 3, 2025
ff02a02
feat: xls utilization report parsing with vivado
Girjoaba Aug 12, 2025
c6a9ccd
wip: cnn
Girjoaba Aug 19, 2025
81af6b6
feat: prepared writer weights for CNNs
Girjoaba Aug 21, 2025
5c42f5c
feat: conv2d_latency is now code generated
Girjoaba Aug 22, 2025
092a809
reverted look_up tables
Girjoaba Aug 28, 2025
90c3231
feat: timing report when building the project
Girjoaba Aug 31, 2025
7 changes: 0 additions & 7 deletions docs/requirements.txt

This file was deleted.

2 changes: 2 additions & 0 deletions hls4ml/backends/__init__.py
@@ -10,6 +10,7 @@
from hls4ml.backends.catapult.catapult_backend import CatapultBackend # isort: skip

from hls4ml.backends.vitis.vitis_backend import VitisBackend # isort: skip
from hls4ml.backends.xls.xls_backend import XLSBackend

register_backend('Vivado', VivadoBackend)
register_backend('VivadoAccelerator', VivadoAcceleratorBackend)
@@ -18,3 +19,4 @@
register_backend('Catapult', CatapultBackend)
register_backend('SymbolicExpression', SymbolicExpressionBackend)
register_backend('oneAPI', OneAPIBackend)
register_backend('XLS', XLSBackend)
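
With the XLS backend registered above, a model can target it by name through the usual hls4ml conversion entry points. A minimal sketch, assuming the standard Keras flow applies unchanged to the new backend; 'keras_model' stands in for an already-built Keras model and 'my_xls_prj' is a hypothetical output directory:

import hls4ml

# Derive a per-layer precision config from the Keras model, then convert it
# targeting the newly registered 'XLS' backend.
config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    keras_model,
    hls_config=config,
    backend='XLS',            # name registered in this file
    output_dir='my_xls_prj',  # hypothetical output directory
)
hls_model.compile()  # write the project and compile it for software-level prediction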
12 changes: 9 additions & 3 deletions hls4ml/backends/backend.py
@@ -1,3 +1,10 @@
# Typing imports
from __future__ import annotations # makes all annotations into strings
from typing import List, Any, TYPE_CHECKING
if TYPE_CHECKING:
pass # Add typing classes here

from numpy.lib._iotools import str2bool
import inspect
import os
from pathlib import Path
@@ -56,7 +63,7 @@ def _get_layer_initializers(self):
def _get_layer_templates(self):
return [name for name in get_backend_passes(self.name) if isinstance(get_optimizer(name), Template)]

def create_initial_config(self, **kwargs):
def create_initial_config(self, **kwargs) -> dict[str, Any]:
"""Create the minimal conversion config for the backend.

Subclasses should implement this method to provide the initial configuration for the conversion.
@@ -82,7 +89,7 @@ def get_available_flows(self):
"""
return get_backend_flows(self.name)

def get_default_flow(self):
def get_default_flow(self) -> str:
"""The name of the default flow of the backend.

Default flow is used as the conversion target if the target flow has not been specified.
@@ -152,7 +159,6 @@ def register_template(self, template_cls):

backend_map = {}


def register_backend(name, backend_cls):
"""Create the backend instance and add it to the registry.

9 changes: 8 additions & 1 deletion hls4ml/backends/fpga/fpga_backend.py
@@ -1,3 +1,9 @@
# Typing imports
from __future__ import annotations # makes all annotations into strings
from typing import List, Any, TYPE_CHECKING
if TYPE_CHECKING:
pass # Add typing classes here

import math
import re
import subprocess
@@ -187,6 +193,7 @@ def compile(self, model):

return lib_name


def write(self, model):
"""Write the generated project to disk.

@@ -199,7 +206,7 @@ def write(self, model):

model.apply_flow(self.get_writer_flow())

def get_writer_flow(self):
def get_writer_flow(self) -> str:
raise NotImplementedError

def get_layer_mult_size(self, layer):
Empty file added hls4ml/backends/xls/__init__.py
Empty file.
268 changes: 268 additions & 0 deletions hls4ml/backends/xls/passes/build_attr.py
@@ -0,0 +1,268 @@
# Typing imports
from __future__ import annotations # makes all annotations into strings
from typing import List, Literal, Any, Optional, Callable, TYPE_CHECKING
from numpy.typing import NDArray
if TYPE_CHECKING:
from hls4ml.model.graph import ModelGraph
from hls4ml.model.layers import Layer


from hls4ml.model.optimizer import OptimizerPass

from functools import wraps
import numpy as np
from fxpmath import Fxp


class XLSAttrBuilder:
"""A helper class that sets XLS specific attributes for the layers of the original ModelGraph.
In doing so, we simplify the process of creating new optimization passes
and constructing the writer class.
The new attributes must be accessed with '.get_attr(...)'

New attributes:
write_weights (bool): the layer contains weights that should be explicitly defined in the project file
write_dims (bool): the layer dimensions should be explicitly written in the project file
write_func (bool): the layer has a corresponding function call that should be explicitly written
as part of the NN architecture in the project file
func_call (str): the corresponding layer DSLX function call

in_dim_key, out_dim_key (str): the variable names holding the layer's input and output dimensions
in_dim_val, out_dim_val (int): the values of the layer's input and output dimensions

fxp_weights (np.ndarray): already quantized weight matrix
fxp_bias (np.ndarray): already quantized bias vector

in_nb, in_en, in_bu (str): fixed-point parameters of the input vector used in DSLX:
number of bits (width), is-negative flag, binary unsigned exponent (fractional bits)
out_nb, out_en, out_bu (str): fixed-point parameters of the output vector used in DSLX:
number of bits (width), is-negative flag, binary unsigned exponent (fractional bits)

Args:
node (Layer): A layer of the model graph
"""

def __init__(self, node) -> None:
self.node = node

@staticmethod
def attach_to_node(attr_name: Optional[str] = None):
"""A decorator-factory to easily chain 'set_attr' commands to the node.
It calls the provided function. This eliminates a lot of boiler plate code.
All the added attributes can be chained in one call since the wrapped function returns self.
"""
def decorator(fn) -> Callable:
name = attr_name or fn.__name__
@wraps(fn)
def wrapped(self, *args, **kwargs):
val = fn(self, *args, **kwargs)
self.node.set_attr(name, val)
return self
return wrapped
return decorator
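# Example: XLSAttrBuilder(node).write_weights().in_en() calls each wrapped method,
# stores its result via node.set_attr('write_weights', ...) and node.set_attr('in_en', ...),
# and returns the builder so further attributes can be chained.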

@attach_to_node()
def write_weights(self) -> bool:
return self.node.class_name in ['Dense', 'Conv2D']

@attach_to_node()
def write_dims(self) -> bool:
return self.node.class_name in ['Input', 'Dense', 'Conv2D']

@attach_to_node()
def write_func(self) -> bool:
return self.node.class_name in ['Dense', 'Activation', 'Softmax', 'Conv2D']


@attach_to_node()
def in_dim_key(self, k: str) -> str:
return k

@attach_to_node()
def in_dim_val(self, v: int) -> int:
return v

@attach_to_node()
def out_dim_key(self, k: str) -> str:
return k

@attach_to_node()
def out_dim_val(self, v: int) -> int:
return v

@attach_to_node()
def fxp_weights(self, weights, out_dim: int, in_dim: int) -> NDArray[NDArray[np.int_]]:
# TODO: check which element of the precision array we should take. Currently we assume the precision of the weights is the first element.
# has weights
if len(weights) >= 1:
width = int(self.node.get_attr('in_nb').split(':', 1)[1])
frac = int(self.node.get_attr('in_bu').split(':', 1)[1])
# Conv
if self.node.class_name == 'Conv2D':
n_chan = self.node.get_attr('n_chan')
filt_height = self.node.get_attr('filt_height')
filt_width = self.node.get_attr('filt_width')
n_filt = self.node.get_attr('n_filt')
mat = np.array(list(list(weights)[0])).reshape(filt_height, filt_width, n_chan, n_filt)
mat_T = np.transpose(mat, (3, 2, 0, 1)) # in Keras the weights are transposed
fxp_w: NDArray[NDArray[np.int_]] = Fxp(mat_T, signed=True, n_word=width, n_frac=frac).raw()
return fxp_w

# Dense
elif self.node.class_name == 'Dense':
mat = np.array(list(list(weights)[0])).reshape(in_dim, out_dim)
mat_T = mat.T # in Keras the weights are transposed
fxp_w: NDArray[NDArray[np.int_]] = Fxp(mat_T, signed=True, n_word=width, n_frac=frac).raw()
return fxp_w
return np.array([])
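# Worked example (assuming the fxpmath raw representation is value * 2**n_frac):
# a weight of 0.5 quantized with n_word=16, n_frac=10 should produce a raw value of
# 0.5 * 2**10 = 512, i.e. Fxp(0.5, signed=True, n_word=16, n_frac=10).raw() == 512.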

@attach_to_node()
def fxp_bias(self, weights) -> NDArray[np.int_]:
# TODO: check which element of the precision array we should take. Currently we assume the precision of the weights is the first element.
# has bias
if len(weights) >= 2:
width = int(self.node.get_attr('in_nb').split(':', 1)[1])
frac = int(self.node.get_attr('in_bu').split(':', 1)[1])
fxp_b: NDArray[np.int_] = Fxp(list(list(weights)[1]), signed=True, n_word=width, n_frac=frac).raw()
return fxp_b
return np.array([])

@attach_to_node()
def in_nb(self, prev_layer_precision: dict | None) -> str: # TODO: right now we only care about the first defined type in the list
if prev_layer_precision:
for _, type_var in prev_layer_precision.items():
return f'u32:{type_var.precision.width}'
return ''

@attach_to_node()
def in_en(self) -> Literal['u32:1']:
return 'u32:1'

@attach_to_node()
def in_bu(self, prev_layer_precision: dict | None) -> str:
if prev_layer_precision:
for _, type_var in prev_layer_precision.items():
return f'u32:{type_var.precision.width - type_var.precision.integer}'
return ''

@attach_to_node()
def out_nb(self, layer_precision: dict) -> str:
if layer_precision.get('result_t', False):
width = layer_precision['result_t'].precision.width
return f'u32:{width}'
for _, type_var in layer_precision.items():
return f'u32:{type_var.precision.width}'
return ''

@attach_to_node()
def out_en(self) -> Literal['u32:1']:
return 'u32:1'

@attach_to_node()
def out_bu(self, layer_precision) -> str:
if layer_precision.get('result_t', False):
width = layer_precision['result_t'].precision.width
integer = layer_precision['result_t'].precision.integer
return f'u32:{width - integer}'
for _, type_var in layer_precision.items():
return f'u32:{type_var.precision.width - type_var.precision.integer}'
return ''

@attach_to_node()
def in_type(self) -> str:
return f'sN[{self.node.get_attr("in_nb")}]'

@attach_to_node()
def out_type(self) -> str:
return f'sN[{self.node.get_attr("out_nb")}]'

@attach_to_node()
def func_call(self) -> str:
func_call_str = ''
if self.node.class_name == 'Dense':
func_call_str = f'fc::dense<{self.node.get_attr("in_nb")}, {self.node.get_attr("in_en")}, {self.node.get_attr("in_bu")}, {self.node.get_attr("out_nb")}, {self.node.get_attr("out_en")}, {self.node.get_attr("out_bu")}>'

elif self.node.class_name == 'Conv2D':
func_call_str = f'conv2d::conv2d_latency<{self.node.get_attr("in_nb")}, {self.node.get_attr("in_en")}, {self.node.get_attr("in_bu")}, {self.node.get_attr("out_nb")}, {self.node.get_attr("out_en")}, {self.node.get_attr("out_bu")}>'

elif self.node.class_name == 'Activation':
func_call_str = f'activations::relu<{self.node.get_attr("out_nb")}>'

elif self.node.class_name == 'Softmax':
implementation = dict(self.node.attributes).get('implementation', 'stable')
if implementation == 'stable':
table_size = dict(self.node.attributes)['table_size']
exp_width = self.node.get_layer_precision()['softmax_exp_table_t'].precision.width
exp_frac = exp_width - self.node.get_layer_precision()['softmax_exp_table_t'].precision.integer
inv_width = self.node.get_layer_precision()['softmax_inv_table_t'].precision.width
inv_frac = inv_width - self.node.get_layer_precision()['softmax_inv_table_t'].precision.integer

func_call_str = (
f"lookup_tables::softmax_stable<"
f"{self.node.get_attr('in_nb')}, {self.node.get_attr('in_en')}, {self.node.get_attr('in_bu')}, "
f" {self.node.get_attr('out_nb')}, {self.node.get_attr('out_en')}, {self.node.get_attr('out_bu')}, "
f"u32:{exp_width}, u32:1, u32:{exp_frac}, "
f"u32:{inv_width}, u32:1, u32:{inv_frac}, "
f"u32:{table_size}>"
)
elif implementation == 'latency':
table_size = dict(self.node.attributes)['table_size']
func_call_str = f'lookup_tables::softmax_latency<{self.node.get_attr("in_nb")}, {self.node.get_attr("in_en")}, {self.node.get_attr("in_bu")}, {self.node.get_attr("out_nb")}, {self.node.get_attr("out_en")}, {self.node.get_attr("out_bu")}, u32:{table_size}>'
elif implementation == 'argmax':
func_call_str = f'activations::argmax<{self.node.get_attr("in_nb")}, {self.node.get_attr("in_en")}, {self.node.get_attr("in_bu")}, {self.node.get_attr("out_nb")}, {self.node.get_attr("out_en")}, {self.node.get_attr("out_bu")}>'
return func_call_str
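# Example of a generated call for a hypothetical Dense layer with 16-bit,
# 10-fractional-bit input and output types:
# 'fc::dense<u32:16, u32:1, u32:10, u32:16, u32:1, u32:10>'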


class BuildAttr(OptimizerPass):
"""Builds the XLS specific attributes for all layers.
"""

def match(self, node: Layer) -> bool:
if node.class_name == 'Input':
return True
return False

def transform(self, model: ModelGraph, node: Layer) -> Literal[False]:
prev_out_dim_key = ''
prev_out_dim_val = -1
prev_layer_precision = None

for layer in model.get_layers():
curr_out_dim_key: str = list(layer.get_output_variable().get_shape())[0][0]
curr_out_dim_val: int = list(layer.get_output_variable().get_shape())[0][1]

curr_weights = layer.get_weights()
curr_prec: dict = layer.get_layer_precision()

# uses the builder to add all the attributes
b = XLSAttrBuilder(layer)
(b
.write_dims()
.write_weights()
.write_func()
.in_dim_key(prev_out_dim_key)
.in_dim_val(prev_out_dim_val)
.out_dim_key(curr_out_dim_key)
.out_dim_val(curr_out_dim_val)
.in_nb(prev_layer_precision)
.in_en()
.in_bu(prev_layer_precision)
.out_nb(curr_prec)
.out_en()
.out_bu(curr_prec)
.in_type()
.out_type()
.fxp_weights(curr_weights, out_dim=curr_out_dim_val, in_dim=prev_out_dim_val)
.fxp_bias(curr_weights)
.func_call()

)

prev_out_dim_key = curr_out_dim_key
prev_out_dim_val = curr_out_dim_val
prev_layer_precision = curr_prec

return False
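# After this pass runs, downstream passes and the writer can read the new attributes,
# e.g. layer.get_attr('func_call') for the DSLX call string or
# layer.get_attr('fxp_weights') for the quantized weight matrix.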
