Make Unet configurable #26
base: master
Conversation
This sounds like a good idea to support! I will take a look at the changes in the PR as well!
src/model.jl
Outdated
```diff
-Chain(Conv(kernel, in_chs => out_chs, pad = (1, 1); init = _random_normal),
-      BatchNormWrap(out_chs),
-      x -> leakyrelu.(x, 0.2f0))
+struct ConvBlock
```
This should be parameterised
```diff
-struct ConvBlock
+struct ConvBlock{T}
```
done
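For context on why the parameter matters: with an untyped field, Julia stores the layers as `Any` and every forward call dispatches dynamically, whereas the parameterised version bakes the concrete layer type into the struct so the compiler can specialise. A minimal illustration (not from this PR):

```julia
# Untyped field: stored as Any, so calls through `layers`
# cannot be inferred and dispatch at runtime.
struct UntypedBlock
    layers
end

# Parameterised field: the concrete type of `layers` is part of
# the struct's type, letting the compiler specialise the forward pass.
struct TypedBlock{T}
    layers::T
end
```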
src/model.jl
Outdated
```diff
-struct UNetUpBlock
-    upsample
+function ConvBlock(in_channels, out_channels, kernel_sizes = [(3,3), (3,3)];
```
We can mirror the Flux `Conv` API here - this would mean we switch out the positions of the kernels and make the channels a `Pair`. This does lose out on having default kernel pairs, though. Are we expecting not to need to define the kernels much? I suppose we wouldn't need to, but maintaining that consistency in the API would be a plus. What do you think?
Yes, good idea. There is a way to allow an optional first argument, since the number of arguments changes between the two forms. The first argument is now optional, followed by the mentioned `Pair`.
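To make that shape concrete, here is a minimal self-contained sketch assuming a parameterised `ConvBlock{T}` wrapper as suggested above; the `layers` field name, the `activation` keyword, and the `SamePad` padding are illustrative choices, not lifted from the PR:

```julia
using Flux, NNlib

# Parameterised wrapper as suggested above (field name assumed).
struct ConvBlock{T}
    layers::T
end
(c::ConvBlock)(x) = c.layers(x)
Flux.@functor ConvBlock

# Kernels as an optional leading argument; channels mirror Flux's
# Conv(filter, in => out) API as a Pair.
function ConvBlock(kernel_sizes::AbstractVector, channels::Pair;
                   activation = NNlib.relu)
    in_chs, out_chs = channels
    chs = [in_chs; fill(out_chs, length(kernel_sizes))]
    ConvBlock(Chain([Conv(k, chs[i] => chs[i + 1], activation; pad = SamePad())
                     for (i, k) in enumerate(kernel_sizes)]...))
end

# Omitting the kernels falls back to two 3x3 convolutions.
ConvBlock(channels::Pair; kwargs...) =
    ConvBlock([(3, 3), (3, 3)], channels; kwargs...)
```

Usage then reads like Flux: `ConvBlock(1 => 16)` or `ConvBlock([(5, 5), (3, 3)], 1 => 16)`.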
src/model.jl
Outdated
```diff
-            stride = (2, 2); init = _random_normal),
-      BatchNormWrap(out_chs),
-      Dropout(p)))
+struct Downsample
```
Same comment about parameterisation as earlier.
done
src/model.jl
Outdated
```diff
-    conv_down_blocks
-    conv_blocks
-    up_blocks
+function Downsample(downsample_factor; pooling_type = "max")
```
If we took the function `maxpool` or `meanpool` directly, we could get rid of the conditional.
Good idea. Changed accordingly.
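A minimal sketch of the conditional-free version, assuming the window size is forwarded straight to the pooling function (names are illustrative, not taken from the PR):

```julia
using NNlib: maxpool, meanpool

# Store the pooling function itself instead of a string flag;
# any f(x, window) -> array works, so no conditional is needed.
struct Downsample{F}
    pool::F
    factor::NTuple{2,Int}
end

Downsample(factor; pool = maxpool) = Downsample(pool, factor)

# NNlib's maxpool/meanpool take the window size as a tuple.
(d::Downsample)(x) = d.pool(x, d.factor)
```

With this, `Downsample((2, 2); pool = meanpool)` covers what a `pooling_type = "mean"` branch would have done.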
````julia
```jldoctest
```
"""
function Unet(; # all arguments are named and have defaults
````
What do you think about retaining some of the positional argument versions as well?
Do you mean as above, to mirror the Flux API? Or are there specific arguments you think should be positional?
I feel the channels can be made a `Pair`, `in => out`.
done
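As a toy stand-in for the resulting calling pattern (this is not the PR's actual constructor): an optional leading `Pair` coexists cleanly with keyword defaults.

```julia
# Toy stand-in: channels as an optional positional Pair,
# everything else keyword-only with defaults.
function Unet(channels::Pair = 1 => 1; depth = 4)
    in_chs, out_chs = channels
    (in_chs = in_chs, out_chs = out_chs, depth = depth)  # placeholder return
end

Unet()                  # (in_chs = 1, out_chs = 1, depth = 4)
Unet(3 => 2)            # (in_chs = 3, out_chs = 2, depth = 4)
Unet(3 => 2; depth = 5) # (in_chs = 3, out_chs = 2, depth = 5)
```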
Co-authored-by: Dhairya Gandhi <[email protected]>
Even for conv chains, the type can't be shared, since different numbers of convolutions are permissible.
Following. Where did we end up here?
Hi @mkitti
Can confirm, got sidetracked 😄
I've added a couple comments, but it looks mostly good! I would love to add this in!
src/model.jl
Outdated
```julia
activation = NNlib.relu,
final_activation = NNlib.relu,
padding = "same",
pooling_type = "max"
```
I would prefer to take in the function directly if possible - that way we simply call the input function, and users can specify their own pooling if they prefer.
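Concretely, the keyword could then hold the function itself, so user-defined pooling drops in with no string matching. A small standalone sketch (the `downsample` helper and keyword names here are hypothetical, not the PR's code):

```julia
using NNlib: maxpool, meanpool

# Toy stand-in: the keyword takes the pooling function directly.
downsample(x; factor = (2, 2), pool = maxpool) = pool(x, factor)

x = rand(Float32, 8, 8, 1, 1)
size(downsample(x))                   # (4, 4, 1, 1)
size(downsample(x; pool = meanpool))  # (4, 4, 1, 1)

# Any f(x, window) works, e.g. a user-defined pooling:
my_pool(x, k) = meanpool(x, k; pad = 0)
size(downsample(x; pool = my_pool))   # (4, 4, 1, 1)
```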
I think we addressed all your suggestions. Could you have another look, @DhairyaLGandhi?
Hi @DhairyaLGandhi
I often need flexibility with the hyperparameters of the UNet, e.g. different numbers of downsamplings and convolutions, kernel sizes, etc., so we started implementing this more configurable version.
This currently allows for a UNet with configurable numbers of downsamplings and convolutions, kernel sizes, activations, padding, and pooling.
Let me know if this is something you would be interested in including here.
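For a sense of the intended usage, a hedged sketch of a configured model built against this branch; the keyword names below are assembled from the diffs above (`pooling` reflecting the switch away from `pooling_type`) and may not match the final API exactly:

```julia
using NNlib

# Hypothetical call against this PR branch; keyword names are
# assumptions based on the diff, not a confirmed API.
model = Unet(1 => 2;
             kernel_sizes = [(3, 3), (3, 3)],
             activation = NNlib.relu,
             final_activation = NNlib.sigmoid,
             padding = "same",
             pooling = NNlib.maxpool)
```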