forked from YdrMaster/operators
-
Notifications
You must be signed in to change notification settings - Fork 24
Add Batch Norm #111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Ziminli
wants to merge
12
commits into
dev
Choose a base branch
from
add_batch_norm
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add Batch Norm #111
Changes from 8 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
4fdd9b1
Add batch norm CPU and CUDA implementation
Ziminli 25473cc
Remove unnecessary testing utils in the frontend test
Ziminli 4f15232
Add inplace testing
Ziminli 6120325
Rebase to the latest dev
Ziminli 1526ce5
Remove redundant include directory
Ziminli a4c323a
Simplify ndim and shape checks
Ziminli 9cd7d0b
Correct test assertion condition
Ziminli e60bc7b
add newline after infini_operators.h
Ziminli 9205475
Changed test structure
Ziminli 8b4b50b
Loose fp32 test restriction
Ziminli 80e8397
Fix misc., optimize code structure
Ziminli a5f29f0
Correct ascend to npu in device_enum_to_str
Ziminli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| #ifndef BATCH_NORM_H | ||
| #define BATCH_NORM_H | ||
|
|
||
| #include "../../export.h" | ||
| #include "../../operators.h" | ||
|
|
||
| typedef struct BatchNormDescriptor { | ||
| Device device; | ||
| } BatchNormDescriptor; | ||
|
|
||
| typedef BatchNormDescriptor *infiniopBatchNormDescriptor_t; | ||
|
|
||
| __C __export infiniopStatus_t infiniopCreateBatchNormDescriptor(infiniopHandle_t handle, | ||
| infiniopBatchNormDescriptor_t *desc_ptr, | ||
| infiniopTensorDescriptor_t y, | ||
| infiniopTensorDescriptor_t x, | ||
| infiniopTensorDescriptor_t scale, | ||
| infiniopTensorDescriptor_t b, | ||
| infiniopTensorDescriptor_t mean, | ||
| infiniopTensorDescriptor_t var, | ||
| double eps); | ||
|
|
||
| __C __export infiniopStatus_t infiniopBatchNorm(infiniopBatchNormDescriptor_t desc, | ||
| void *y, | ||
| void const *x, | ||
| void const *scale, | ||
| void const *b, | ||
| void const *mean, | ||
| void const *var, | ||
| void *stream); | ||
|
|
||
| __C __export infiniopStatus_t infiniopDestroyBatchNormDescriptor(infiniopBatchNormDescriptor_t desc); | ||
|
|
||
| #endif |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,230 @@ | ||
| from ctypes import POINTER, Structure, c_int32, c_void_p, c_double | ||
| import ctypes | ||
| import sys | ||
| import os | ||
| import time | ||
|
|
||
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) | ||
| from operatorspy import ( | ||
| open_lib, | ||
| to_tensor, | ||
| DeviceEnum, | ||
| infiniopHandle_t, | ||
| infiniopTensorDescriptor_t, | ||
| create_handle, | ||
| destroy_handle, | ||
| check_error, | ||
| ) | ||
|
|
||
| from operatorspy.tests.test_utils import get_args | ||
| from enum import Enum, auto | ||
| import torch | ||
| import ctypes | ||
| import torch.nn.functional as F | ||
|
|
||
| # constant for control whether profile the pytorch and lib functions | ||
| # NOTE: need to manually add synchronization function to the lib function, | ||
| # e.g., cudaDeviceSynchronize() for CUDA | ||
| PROFILE = False | ||
| NUM_PRERUN = 10 | ||
| NUM_ITERATIONS = 1000 | ||
|
|
||
| class Inplace(Enum): | ||
| OUT_OF_PLACE = auto() | ||
| INPLACE_X = auto() | ||
|
|
||
|
|
||
| class BatchNormDescriptor(Structure): | ||
| _fields_ = [("device", c_int32)] | ||
|
|
||
|
|
||
| infiniopBatchNormDescriptor_t = POINTER(BatchNormDescriptor) | ||
|
|
||
|
|
||
| def batch_norm(x, scale, b, mean, var, eps): | ||
| ndim = len(x.shape) | ||
| if ndim <= 1 or ndim > 5: | ||
| print("Error: Pytorch -> Unsupported tensor dimension") | ||
| return None | ||
| if PROFILE: | ||
| ans = F.batch_norm(x, mean, var, scale, b, training=False, eps=eps) | ||
| torch.cuda.synchronize() | ||
| return ans | ||
| return F.batch_norm(x, mean, var, scale, b, training=False, eps=eps) | ||
|
|
||
|
|
||
| # get the mean and variance of the input tensor across the batch size N and spatial dimensions | ||
| def get_mean_variance(x, dtype): | ||
| dims = tuple(range(x.ndim)) | ||
| reduction_dims = tuple(d for d in dims if d != 1) # Exclude the channel dimension | ||
| return x.mean(dim=reduction_dims, dtype=dtype), x.var( | ||
| dim=reduction_dims, unbiased=False | ||
| ).to(dtype) | ||
|
|
||
|
|
||
| def test( | ||
| lib, | ||
| handle, | ||
| torch_device, | ||
| x_shape, | ||
| eps=1e-5, | ||
| tensor_dtype=torch.float16, | ||
| inplace=Inplace.OUT_OF_PLACE, | ||
| ): | ||
| print( | ||
| f"Testing BatchNorm on {torch_device} with x_shape: {x_shape}, scale_shape: {x_shape[1]}, b_shape: {x_shape[1]}, mean_shape: {x_shape[1]}, var_shape: {x_shape[1]}, eps: {eps}, dtype:{tensor_dtype}, Inplace:{inplace}" | ||
| ) | ||
| num_channel = x_shape[1] | ||
| bn_dtype = tensor_dtype if tensor_dtype != torch.float16 else torch.float32 | ||
| x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device) * 10 - 2 | ||
| scale = torch.rand(num_channel, dtype=bn_dtype).to(torch_device) | ||
| b = torch.rand(num_channel, dtype=bn_dtype).to(torch_device) | ||
| mean, var = get_mean_variance(x, bn_dtype) | ||
| y = torch.zeros(x_shape, dtype=tensor_dtype).to(torch_device) if inplace == Inplace.OUT_OF_PLACE else x | ||
|
|
||
| # get the pytorch answer | ||
| for i in range(NUM_PRERUN if PROFILE else 1): | ||
| ans = batch_norm(x, scale, b, mean, var, eps) | ||
| if PROFILE: | ||
| start_time = time.time() | ||
| for i in range(NUM_ITERATIONS): | ||
| _ = batch_norm(x, scale, b, mean, var, eps) | ||
| elapsed = (time.time() - start_time) / NUM_ITERATIONS | ||
| print(f"pytorch time: {elapsed :6f}") | ||
|
|
||
| # get the operators' answer | ||
| x_tensor = to_tensor(x, lib) | ||
| scale_tensor = to_tensor(scale, lib) | ||
| b_tensor = to_tensor(b, lib) | ||
| mean_tensor = to_tensor(mean, lib) | ||
| var_tensor = to_tensor(var, lib) | ||
| y_tensor = to_tensor(y, lib) if inplace == Inplace.OUT_OF_PLACE else x_tensor | ||
| descriptor = infiniopBatchNormDescriptor_t() | ||
|
|
||
| check_error( | ||
| lib.infiniopCreateBatchNormDescriptor( | ||
| handle, | ||
| ctypes.byref(descriptor), | ||
| y_tensor.descriptor, | ||
| x_tensor.descriptor, | ||
| scale_tensor.descriptor, | ||
| b_tensor.descriptor, | ||
| mean_tensor.descriptor, | ||
| var_tensor.descriptor, | ||
| eps, | ||
| ) | ||
| ) | ||
|
|
||
| for i in range(NUM_PRERUN if PROFILE else 1): | ||
| check_error( | ||
| lib.infiniopBatchNorm( | ||
| descriptor, | ||
| y_tensor.data, | ||
| x_tensor.data, | ||
| scale_tensor.data, | ||
| b_tensor.data, | ||
| mean_tensor.data, | ||
| var_tensor.data, | ||
| None, | ||
| ) | ||
| ) | ||
| if PROFILE: | ||
| start_time = time.time() | ||
| for i in range(NUM_ITERATIONS): | ||
| lib.infiniopBatchNorm( | ||
| descriptor, | ||
| y_tensor.data, | ||
| x_tensor.data, | ||
| scale_tensor.data, | ||
| b_tensor.data, | ||
| mean_tensor.data, | ||
| var_tensor.data, | ||
| None, | ||
| ) | ||
| elapsed = (time.time() - start_time) / NUM_ITERATIONS | ||
| print(f" lib time: {elapsed :6f}") | ||
|
|
||
| if (tensor_dtype == torch.float16): | ||
| assert torch.allclose(y, ans, atol=1e-5, rtol=1e-3) | ||
| else: # float32 | ||
| assert torch.allclose(y, ans, atol=1e-7, rtol=1e-3) | ||
| check_error(lib.infiniopDestroyBatchNormDescriptor(descriptor)) | ||
|
|
||
|
|
||
| def test_cpu(lib, test_cases): | ||
| device = DeviceEnum.DEVICE_CPU | ||
| handle = create_handle(lib, device) | ||
| for x_shape, eps, inplace in test_cases: | ||
| test(lib, handle, "cpu", x_shape, eps, tensor_dtype=torch.float16, inplace=inplace) | ||
| test(lib, handle, "cpu", x_shape, eps, tensor_dtype=torch.float32, inplace=inplace) | ||
| destroy_handle(lib, handle) | ||
|
|
||
|
|
||
| def test_cuda(lib, test_cases): | ||
| device = DeviceEnum.DEVICE_CUDA | ||
| handle = create_handle(lib, device) | ||
| for x_shape, eps, inplace in test_cases: | ||
| test(lib, handle, "cuda", x_shape, eps, tensor_dtype=torch.float16, inplace=inplace) | ||
| test(lib, handle, "cuda", x_shape, eps, tensor_dtype=torch.float32, inplace=inplace) | ||
| destroy_handle(lib, handle) | ||
|
|
||
|
|
||
| def test_bang(lib, test_cases): | ||
| import torch_mlu | ||
|
|
||
| device = DeviceEnum.DEVICE_BANG | ||
| handle = create_handle(lib, device) | ||
| for x_shape, eps, inplace in test_cases: | ||
| test(lib, handle, "mlu", x_shape, eps, tensor_dtype=torch.float16, inplace=inplace) | ||
| test(lib, handle, "mlu", x_shape, eps, tensor_dtype=torch.float32, inplace=inplace) | ||
| destroy_handle(lib, handle) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_cases = [ | ||
| # x_shape, eps, inplace | ||
| ((2, 5, 7), 1e-5, Inplace.OUT_OF_PLACE), | ||
| # ((2, 5, 7), 1e-5, Inplace.INPLACE_X), | ||
| ((32, 3, 1024), 1e-5, Inplace.OUT_OF_PLACE), | ||
| ((32, 3, 128, 128), 1e-5, Inplace.OUT_OF_PLACE), | ||
| ((32, 3, 64, 64, 64), 1e-5, Inplace.OUT_OF_PLACE), | ||
| ] | ||
| args = get_args() | ||
| lib = open_lib() | ||
| lib.infiniopCreateBatchNormDescriptor.restype = c_int32 | ||
| lib.infiniopCreateBatchNormDescriptor.argtypes = [ | ||
| infiniopHandle_t, | ||
| POINTER(infiniopBatchNormDescriptor_t), | ||
| infiniopTensorDescriptor_t, | ||
| infiniopTensorDescriptor_t, | ||
| infiniopTensorDescriptor_t, | ||
| infiniopTensorDescriptor_t, | ||
| infiniopTensorDescriptor_t, | ||
| infiniopTensorDescriptor_t, | ||
| c_double, | ||
| ] | ||
| lib.infiniopBatchNorm.restype = c_int32 | ||
| lib.infiniopBatchNorm.argtypes = [ | ||
| infiniopBatchNormDescriptor_t, | ||
| c_void_p, | ||
| c_void_p, | ||
| c_void_p, | ||
| c_void_p, | ||
| c_void_p, | ||
| c_void_p, | ||
| c_void_p, | ||
| ] | ||
| lib.infiniopDestroyBatchNormDescriptor.restype = c_int32 | ||
| lib.infiniopDestroyBatchNormDescriptor.argtypes = [ | ||
| infiniopBatchNormDescriptor_t, | ||
| ] | ||
|
|
||
| if args.cpu: | ||
| test_cpu(lib, test_cases) | ||
| if args.cuda: | ||
| test_cuda(lib, test_cases) | ||
| if args.bang: | ||
| test_bang(lib, test_cases) | ||
| if not (args.cpu or args.cuda or args.bang): | ||
| test_cpu(lib, test_cases) | ||
| print("\033[92mTest passed!\033[0m") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| #include "batch_norm_cpu.h" | ||
| #include "../../../devices/cpu/common_cpu.h" | ||
| #include "../../utils.h" | ||
|
|
||
| infiniopStatus_t cpuCreateBatchNormDescriptor(infiniopHandle_t, | ||
| BatchNormCpuDescriptor_t *desc_ptr, | ||
| infiniopTensorDescriptor_t y, | ||
| infiniopTensorDescriptor_t x, | ||
| infiniopTensorDescriptor_t scale, | ||
| infiniopTensorDescriptor_t b, | ||
| infiniopTensorDescriptor_t mean, | ||
| infiniopTensorDescriptor_t var, | ||
| double eps) { | ||
| uint64_t ndim = y->ndim; | ||
| if (ndim != x->ndim || scale->ndim != b->ndim || scale->ndim != mean->ndim || scale->ndim != var->ndim || scale->ndim != 1) { | ||
| return STATUS_BAD_TENSOR_SHAPE; | ||
| } | ||
| for (size_t i = 0; i < ndim; ++i) { | ||
| if (y->shape[i] != x->shape[i]) { | ||
| return STATUS_BAD_TENSOR_SHAPE; | ||
| } | ||
| } | ||
| if (x->shape[1] != scale->shape[0] || scale->shape[0] != b->shape[0] || scale->shape[0] != mean->shape[0] || scale->shape[0] != var->shape[0]) { | ||
| return STATUS_BAD_TENSOR_SHAPE; | ||
| } | ||
| if (!is_contiguous(y) || !is_contiguous(x)) { | ||
| return STATUS_BAD_TENSOR_STRIDES; | ||
| } | ||
| if (y->dt != F16 && y->dt != F32) { | ||
| return STATUS_BAD_TENSOR_DTYPE; | ||
| } | ||
| if (y->dt != x->dt) { | ||
| return STATUS_BAD_TENSOR_DTYPE; | ||
| } | ||
| if (eps < 0) { | ||
| return STATUS_BAD_PARAM; | ||
| } | ||
|
|
||
| uint64_t spatial_data_size = std::accumulate(x->shape + 2, x->shape + x->ndim, 1ULL, std::multiplies<uint64_t>()); | ||
| uint64_t batch_size = x->shape[0]; | ||
| uint64_t channel_size = x->shape[1]; | ||
|
|
||
| *desc_ptr = new BatchNormCpuDescriptor{ | ||
| DevCpu, | ||
| y->dt, | ||
| batch_size, | ||
| channel_size, | ||
| spatial_data_size, | ||
| channel_size * spatial_data_size, | ||
| eps, | ||
| }; | ||
|
|
||
| return STATUS_SUCCESS; | ||
| } | ||
|
|
||
| infiniopStatus_t cpuDestroyBatchNormDescriptor(BatchNormCpuDescriptor_t desc) { | ||
| delete desc; | ||
| return STATUS_SUCCESS; | ||
| } | ||
|
|
||
| template<typename Tdata, typename Pdata> | ||
| infiniopStatus_t batch_norm_cpu(BatchNormCpuDescriptor_t desc, void *y, void const *x, | ||
| void const *scale, void const *b, void const *mean, void const *var) { | ||
| auto x_ = reinterpret_cast<Tdata const *>(x); | ||
| auto scale_ = reinterpret_cast<Pdata const *>(scale); | ||
| auto b_ = reinterpret_cast<Pdata const *>(b); | ||
| auto mean_ = reinterpret_cast<Pdata const *>(mean); | ||
| auto var_ = reinterpret_cast<Pdata const *>(var); | ||
| auto y_ = reinterpret_cast<Tdata *>(y); | ||
|
|
||
| #pragma omp parallel for collapse(3) | ||
| for (uint64_t i = 0; i < desc->batch_size; ++i) { | ||
| for (uint64_t c = 0; c < desc->channel_size; ++c) { | ||
| for (uint64_t j = 0; j < desc->spatial_data_size; ++j) { | ||
| auto idx = (i * desc->channel_size + c) * desc->spatial_data_size + j; | ||
| Pdata invsqrt = 1 / std::sqrt(var_[c] + desc->eps); | ||
| if constexpr (std::is_same<Tdata, uint16_t>::value) { | ||
| y_[idx] = f32_to_f16((f16_to_f32(x_[idx]) - mean_[c]) * invsqrt * scale_[c] + b_[c]); | ||
| } else { | ||
| y_[idx] = (x_[idx] - mean_[c]) * invsqrt * scale_[c] + b_[c]; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| return STATUS_SUCCESS; | ||
| } | ||
|
|
||
| infiniopStatus_t cpuBatchNorm(BatchNormCpuDescriptor_t desc, | ||
| void *y, void const *x, void const *scale, void const *b, | ||
| void const *mean, void const *var, void *stream) { | ||
| if (desc->dtype == F16) { | ||
| return batch_norm_cpu<uint16_t, float>(desc, y, x, scale, b, mean, var); | ||
| } | ||
| if (desc->dtype == F32) { | ||
| return batch_norm_cpu<float, float>(desc, y, x, scale, b, mean, var); | ||
| } | ||
| return STATUS_BAD_TENSOR_DTYPE; | ||
| } |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.