-
-
Notifications
You must be signed in to change notification settings - Fork 67
Implement EfficientNet #171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 13 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
b132c59
Add initial efficient net
darsnack 50ca201
Clean up implementation and add docstrings
darsnack 0ccfc78
Add tests and docs
darsnack e1dee6c
Don't forget to export!
darsnack d7e2e19
Also don't forget to include
darsnack a83efd9
No need to use image resolution scaling (implicitly done)
darsnack 5e5a56e
Use acctest for efficient net
darsnack 7e2e1e4
Fix a copy paste error for _round_channels
darsnack 19a18bf
Splat conv_bn
darsnack 61d886b
Fix size error in test setup
darsnack aa1d37d
Account for width scaling in input channel size
darsnack 01ab049
Only test smaller variants on efficient net
darsnack ea1b92e
Adjust testing and use `Chain(::Vector)`
darsnack b9e238c
Update test/convnets.jl
darsnack 7a32bb0
Adjust ConvNeXt and ConvMixer
darsnack a3f44c8
Test less for EfficientNet too
darsnack File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| """ | ||
| efficientnet(scalings, block_config; | ||
| inchannels = 3, nclasses = 1000, max_width = 1280) | ||
|
|
||
| Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). | ||
|
|
||
| # Arguments | ||
|
|
||
| - `scalings`: global width and depth scaling (given as a tuple) | ||
| - `block_config`: configuration for each inverted residual block, | ||
| given as a vector of tuples with elements: | ||
| - `n`: number of block repetitions (will be scaled by global depth scaling) | ||
| - `k`: kernel size | ||
| - `s`: kernel stride | ||
| - `e`: expansion ratio | ||
| - `i`: block input channels | ||
| - `o`: block output channels (will be scaled by global width scaling) | ||
| - `inchannels`: number of input channels | ||
| - `nclasses`: number of output classes | ||
| - `max_width`: maximum number of output channels before the fully connected | ||
| classification blocks | ||
| """ | ||
| function efficientnet(scalings, block_config; | ||
| inchannels = 3, nclasses = 1000, max_width = 1280) | ||
| wscale, dscale = scalings | ||
| scalew(w) = wscale ≈ 1 ? w : ceil(Int64, wscale * w) | ||
| scaled(d) = dscale ≈ 1 ? d : ceil(Int64, dscale * d) | ||
|
|
||
| out_channels = _round_channels(scalew(32), 8) | ||
| stem = conv_bn((3, 3), inchannels, out_channels, swish; | ||
| bias = false, stride = 2, pad = SamePad()) | ||
|
|
||
| blocks = [] | ||
| for (n, k, s, e, i, o) in block_config | ||
| in_channels = _round_channels(scalew(i), 8) | ||
| out_channels = _round_channels(scalew(o), 8) | ||
| repeats = scaled(n) | ||
|
|
||
| push!(blocks, | ||
| invertedresidual(k, in_channels, in_channels * e, out_channels, swish; | ||
| stride = s, reduction = 4)) | ||
| for _ in 1:(repeats - 1) | ||
| push!(blocks, | ||
| invertedresidual(k, out_channels, out_channels * e, out_channels, swish; | ||
| stride = 1, reduction = 4)) | ||
| end | ||
| end | ||
| blocks = Chain(blocks...) | ||
|
|
||
| head_out_channels = _round_channels(max_width, 8) | ||
| head = conv_bn((1, 1), out_channels, head_out_channels, swish; | ||
| bias = false, pad = SamePad()) | ||
|
|
||
| top = Dense(head_out_channels, nclasses) | ||
|
|
||
| return Chain(Chain([stem..., blocks, head...]), | ||
| Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, top)) | ||
| end | ||
|
|
||
| # n: # of block repetitions | ||
| # k: kernel size k x k | ||
| # s: stride | ||
| # e: expantion ratio | ||
| # i: block input channels | ||
| # o: block output channels | ||
| const efficientnet_block_configs = [ | ||
| # (n, k, s, e, i, o) | ||
| (1, 3, 1, 1, 32, 16), | ||
| (2, 3, 2, 6, 16, 24), | ||
| (2, 5, 2, 6, 24, 40), | ||
| (3, 3, 2, 6, 40, 80), | ||
| (3, 5, 1, 6, 80, 112), | ||
| (4, 5, 2, 6, 112, 192), | ||
| (1, 3, 1, 6, 192, 320) | ||
| ] | ||
|
|
||
| # w: width scaling | ||
| # d: depth scaling | ||
| # r: image resolution | ||
| const efficientnet_global_configs = Dict( | ||
| # ( r, ( w, d)) | ||
| :b0 => (224, (1.0, 1.0)), | ||
| :b1 => (240, (1.0, 1.1)), | ||
| :b2 => (260, (1.1, 1.2)), | ||
| :b3 => (300, (1.2, 1.4)), | ||
| :b4 => (380, (1.4, 1.8)), | ||
| :b5 => (456, (1.6, 2.2)), | ||
| :b6 => (528, (1.8, 2.6)), | ||
| :b7 => (600, (2.0, 3.1)), | ||
| :b8 => (672, (2.2, 3.6)) | ||
| ) | ||
|
|
||
| struct EfficientNet | ||
| layers::Any | ||
| end | ||
|
|
||
| """ | ||
| EfficientNet(scalings, block_config; | ||
| inchannels = 3, nclasses = 1000, max_width = 1280) | ||
|
|
||
| Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). | ||
| See also [`efficientnet`](#). | ||
|
|
||
| # Arguments | ||
|
|
||
| - `scalings`: global width and depth scaling (given as a tuple) | ||
| - `block_config`: configuration for each inverted residual block, | ||
| given as a vector of tuples with elements: | ||
| - `n`: number of block repetitions (will be scaled by global depth scaling) | ||
| - `k`: kernel size | ||
| - `s`: kernel stride | ||
| - `e`: expansion ratio | ||
| - `i`: block input channels | ||
| - `o`: block output channels (will be scaled by global width scaling) | ||
| - `inchannels`: number of input channels | ||
| - `nclasses`: number of output classes | ||
| - `max_width`: maximum number of output channels before the fully connected | ||
| classification blocks | ||
| """ | ||
| function EfficientNet(scalings, block_config; | ||
| inchannels = 3, nclasses = 1000, max_width = 1280) | ||
| layers = efficientnet(scalings, block_config; | ||
| inchannels = inchannels, | ||
| nclasses = nclasses, | ||
| max_width = max_width) | ||
| return EfficientNet(layers) | ||
| end | ||
|
|
||
| @functor EfficientNet | ||
|
|
||
| (m::EfficientNet)(x) = m.layers(x) | ||
|
|
||
| backbone(m::EfficientNet) = m.layers[1] | ||
| classifier(m::EfficientNet) = m.layers[2] | ||
|
|
||
| """ | ||
| EfficientNet(name::Symbol; pretrain = false) | ||
|
|
||
| Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). | ||
| See also [`efficientnet`](#). | ||
|
|
||
| # Arguments | ||
|
|
||
| - `name`: name of default configuration | ||
| (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) | ||
| - `pretrain`: set to `true` to load the pre-trained weights for ImageNet | ||
| """ | ||
| function EfficientNet(name::Symbol; pretrain = false) | ||
| @assert name in keys(efficientnet_global_configs) | ||
| "`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))" | ||
|
|
||
| model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) | ||
| pretrain && loadpretrain!(model, string("efficientnet-", name)) | ||
|
|
||
| return model | ||
| end | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.