Add conv2d Implicit GEMM #15805
base: master
Conversation
…es for 2D convolution
Why are you adding a new ggml op?
Because of #15669 (comment)
I think the implementation of implicit GEMM can directly use ggml_conv2d_direct. There's really no need to provide so many conv2d functions.
I can reuse ggml_conv2d_direct. TBH it is not a very good or intuitive name (the best one, ggml_conv_2d, is already taken). I do wish it had an additional argument specifying which method is implemented (ggml_conv_2d should have carried one from the beginning).
If the performance of implicit GEMM is on par with or even better than that of im2col + GEMM, I think ggml_conv_2d can also adopt the implicit GEMM implementation.
What should be done regarding kernel selection?
For kernel selection, please take a look at how e.g. FLASH_ATTN_EXT is handled.
For this PR, try removing the current conv2d kernel and replacing it with this one. Chances are it will be universally faster since it uses shared memory and has (unless I misread the code) coalesced memory accesses. I'll test the performance using a P40, RTX 3090, and RTX 4090 for NVIDIA and an RX 6800 and Mi 50 for AMD.
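For illustration, a minimal sketch of the access pattern being referred to, with hypothetical tile sizes and names (not the PR's actual code): consecutive threads read consecutive global addresses so each warp's load coalesces, and the tile is then reused many times from shared memory.

```cuda
// dst is expected to point at a shared-memory tile of TILE_M x TILE_K floats.
// Threads with consecutive flat thread ids read consecutive global addresses,
// so the loads coalesce; bounds checks are omitted for brevity.
template <int TILE_M, int TILE_K>
__device__ void load_tile_coalesced(const float * __restrict__ src,
                                    float dst[TILE_M][TILE_K],
                                    int row0, int col0, int ld) {
    const int tid      = threadIdx.y * blockDim.x + threadIdx.x;
    const int nthreads = blockDim.x * blockDim.y;

    for (int i = tid; i < TILE_M * TILE_K; i += nthreads) {
        const int r = i / TILE_K;  // row inside the tile
        const int c = i % TILE_K;  // column inside the tile
        dst[r][c] = src[(row0 + r) * ld + (col0 + c)];
    }
    __syncthreads();  // tile is ready for all threads in the block
}
```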
#include "convert.cuh" | ||
|
||
typedef struct{ | ||
unsigned int n; //batch szie |
Suggested change:
-unsigned int n; //batch szie
+unsigned int n; //batch size
Done
typedef struct{
unsigned int n; //batch szie
unsigned int c; //channel number
Change to either "channel index" or "number of channels" depending on which this is.
done
int threadz = 1; // threadz number per block
dim3 thblock(threadx, thready, threadz);
dim3 grid(blockx, blocky, blockz);
int smem_size = 24 * 1024;
On some CUDA architectures shared memory comes out of the L1 cache, so if at all possible you should reserve only as much as will actually be used.
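A minimal sketch of what that could look like, assuming hypothetical TILE_M/TILE_N/TILE_K tile sizes and a placeholder kernel name (not the PR's actual identifiers): compute the dynamic shared-memory size from the tiles the kernel really stages instead of hard-coding 24 KiB.

```cuda
// Hypothetical tile sizes; in practice they come from the kernel's tiling.
constexpr int TILE_M = 64;   // output channels per block
constexpr int TILE_N = 64;   // output pixels per block
constexpr int TILE_K = 8;    // reduction depth staged per iteration

// Reserve only what the kernel actually uses: one filter tile plus one input
// tile per stage. On architectures where shared memory and L1 share the same
// physical storage, over-reserving shrinks the effective L1 cache.
const size_t smem_size = (TILE_M * TILE_K + TILE_K * TILE_N) * sizeof(float);

// conv2d_implicit_gemm is a placeholder for the PR's kernel name.
conv2d_implicit_gemm<<<grid, thblock, smem_size, stream>>>(input, kernel, output, param);
```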
float * __restrict__ output,
const param_t param) {

extern __shared__ __align__(16 * 1024) char smem[];
What is the purpose of __align__ here?
Removed; no difference in performance.
for (int i = 0; i < 4; ++i)
{
Suggested change:
-for (int i = 0; i < 4; ++i)
-{
+for (int i = 0; i < 4; ++i) {
Done. Corrected the style in all places.
Thanks, @JohannesGaessler, for taking the time to review. I agree with your idea of selecting kernels behind the scenes. Indeed, no single kernel is optimal for all input and filter shapes; that's why cuDNN provides all kinds of them for the user to choose from. Previously I wasn't sure whether selecting kernels is possible, so I'll look into the FLASH_ATTN_EXT example (thanks again).
Now that #15813 is adding tensor core support with shared memory, I don't want to step on it. This PR will be on hold for now. I may contribute to the current conv_2d_direct once the tensor core code is merged.
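To make the kernel-selection idea concrete, here is a rough host-side sketch under stated assumptions; every name below is a placeholder rather than the real ggml API, and the actual mechanism would follow how the CUDA backend already selects FLASH_ATTN_EXT kernels from device and problem properties at runtime.

```cuda
#include <cuda_runtime.h>

// Placeholder problem description; ggml would take these from the tensors.
struct conv2d_params { int n, c_in, c_out, kh, kw, oh, ow; };

// Stub launchers; the real ones would launch the corresponding CUDA kernels.
static void launch_conv2d_implicit_gemm(const conv2d_params &, cudaStream_t) { /* tiled shared-memory kernel */ }
static void launch_conv2d_direct       (const conv2d_params &, cudaStream_t) { /* simple fallback kernel    */ }

// Crude capability check used only for illustration (FP16 is slow on most Pascal GPUs).
static bool device_has_fast_fp16(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 7;
}

// One public conv_2d op, several implementations selected behind the scenes.
static void conv2d_dispatch(const conv2d_params & p, int device, cudaStream_t stream) {
    const bool deep_reduction = p.c_in * p.kh * p.kw >= 64;  // enough K-depth for tiling to pay off

    if (device_has_fast_fp16(device) && deep_reduction) {
        launch_conv2d_implicit_gemm(p, stream);
    } else {
        launch_conv2d_direct(p, stream);
    }
}
```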
Even if there is a kernel with tensor core support, a good kernel without tensor cores would still be extremely useful. P40s and Mi50s are very cheap options for 24/32 GB of VRAM, but they lack tensor cores. And from a ggml perspective it's much easier to squeeze out more performance than it is to compress the weights (without affecting quality).
Speaking of P40s, you should be careful with FP16 arithmetic since that is massively gimped on Pascal. You can use the macro FAST_FP16_AVAILABLE to check whether FP16 would be fast and use FP32 as a workaround if not. You can look at e.g. mmvf.cu for an example.
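A minimal sketch of that pattern, assuming FAST_FP16_AVAILABLE is provided by ggml's CUDA common header as it is used in mmvf.cu; the helper name and K parameter here are illustrative only.

```cuda
#include <cuda_fp16.h>

// Do the multiply-accumulate in half2 where FP16 is fast; on e.g. Pascal,
// convert the loaded halves and accumulate in FP32 instead.
template <int K>
__device__ float dot_fp16(const half * __restrict__ w, const half * __restrict__ x) {
#ifdef FAST_FP16_AVAILABLE
    half2 acc = __float2half2_rn(0.0f);
    for (int k = 0; k < K/2; ++k) {
        acc = __hfma2(((const half2 *) w)[k], ((const half2 *) x)[k], acc);
    }
    return __low2float(acc) + __high2float(acc);
#else
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) {
        acc += __half2float(w[k]) * __half2float(x[k]);  // FP32 workaround
    }
    return acc;
#endif // FAST_FP16_AVAILABLE
}
```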
Will look into it. Thanks.
…ncy; update parameter comments and remove unused code
…test for implicit convolution
This PR adds another CUDA conv_2d op using the implicit GEMM approach. It is optimized for CUDA cores only, and its performance is up to 10x that of the direct method currently in llama.cpp.
On an RTX 4090:
Comparison with im2col + GEMM
FP16 filter, FP32 activation
FP32 filter, FP32 activation
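For readers unfamiliar with the approach, a deliberately naive sketch of the index mapping that "implicit GEMM" refers to (no shared-memory tiling, unit stride, no padding; the struct and field names are illustrative, not the PR's param_t): the convolution is viewed as a GEMM with M = output channels, N = batch × output pixels, K = input channels × filter size, and the im2col matrix is never materialized because its elements are recomputed from indices on the fly.

```cuda
struct conv_params { int n, c, h, w, k, r, s, oh, ow; };  // batch, in-ch, in-h, in-w, out-ch, filter-h, filter-w, out-h, out-w

// One thread per output element; gemm_m walks output channels, gemm_n walks
// n*oh*ow, and gemm_k walks c*r*s, decomposed back into (channel, ky, kx)
// so the input is addressed directly instead of through an im2col buffer.
__global__ void conv2d_implicit_gemm_naive(const float * __restrict__ input,   // [n, c, h, w]
                                           const float * __restrict__ filter,  // [k, c, r, s]
                                           float       * __restrict__ output,  // [n, k, oh, ow]
                                           const conv_params p) {
    const int gemm_n = blockIdx.x * blockDim.x + threadIdx.x;  // which output pixel (flattened)
    const int gemm_m = blockIdx.y * blockDim.y + threadIdx.y;  // which output channel
    if (gemm_m >= p.k || gemm_n >= p.n * p.oh * p.ow) {
        return;
    }

    const int b  = gemm_n / (p.oh * p.ow);
    const int oy = (gemm_n / p.ow) % p.oh;
    const int ox = gemm_n % p.ow;

    float acc = 0.0f;
    for (int gemm_k = 0; gemm_k < p.c * p.r * p.s; ++gemm_k) {
        const int ic = gemm_k / (p.r * p.s);
        const int ky = (gemm_k / p.s) % p.r;
        const int kx = gemm_k % p.s;
        const int iy = oy + ky;  // unit stride and no padding assumed for brevity
        const int ix = ox + kx;
        acc += input [((b * p.c + ic) * p.h + iy) * p.w + ix] *
               filter[((gemm_m * p.c + ic) * p.r + ky) * p.s + kx];
    }
    output[((b * p.k + gemm_m) * p.oh + oy) * p.ow + ox] = acc;
}
```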