-
-
Notifications
You must be signed in to change notification settings - Fork 333
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transfer learning ResNet #395
Conversation
fix typos and add context
fix typos and add context
fix typos and add context
Would be ready for review. |
function getindex(data::ImageContainer, idx::Int) | ||
path = data.img[idx] | ||
img = Images.load(path) | ||
img = apply(tfm, Image(img)) | ||
img = permutedims(channelview(RGB.(itemdata(img))), (3, 2, 1)) | ||
img = Float32.(img) | ||
name = replace(path, r"(.+)\\(.+)\\(.+_\d+)\.jpg" => s"\2") | ||
y = name_to_idx[name] | ||
return img, Flux.onehotbatch(y, 1:3) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of applying the transformation in getindex
it would be better to showcase MLUtils.mapobs
. The advantage is that it is a pattern that can be used with any dataset. An example with mnist is given in the README here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is what you suggest to have a sepration between a minimal getindex
and a transform function, so something like:
function getindex(data::ImageContainer, idx::Int)
path = data.img[idx]
x = Images.load(path)
name = replace(path, r"(.+)\\(.+)\\(.+_\d+)\.jpg" => s"\2")
y = name_to_idx[name]
return (x = x, y = y)
end
function img_transform(batch)
img = apply(tfm, Image(batch[:x]))
img = permutedims(channelview(RGB.(itemdata(img))), (3, 2, 1))
x = Float32.(img)
y = Flux.onehotbatch(batch[:y], 1:3)
return (x, y)
end
dtrain = Flux.mapobs(img_transform, ImageContainer(imgs[1:2700]))
That's out of scope for thie PR, but benchmarking to validate there wasn't any performance difference, both approaches perform essentially the same, but performance degrades following each iteration:
function data_loop(data)
count = 0
for (x, y) in data
count += size(y, 2)
end
@info count
end
@btime data_loop($dtrain)
2.778 s (505360 allocations: 6.86 GiB)
4.000 s (505459 allocations: 6.86 GiB)
5.303 s (505461 allocations: 6.86 GiB)
5.423 s (505424 allocations: 6.86 GiB)
That's with MLUtils 2.11, so may be ignored until tutorial can be updated to latest Flux/MLUtils versions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is what you suggest to have a sepration between a minimal getindex and a transform function
yes
but performance degrades following each iteration
this is awful, if you have some time please open an issue MLUtils.jl
For this PR do as you prefer regarding the use of mapobs or not
```julia | ||
function train_epoch!(m_infer, m_tune; ps, opt, dtrain) | ||
for (x, y) in dtrain | ||
infer = m_infer(x) | ||
grads = gradient(ps) do | ||
Flux.Losses.logitcrossentropy(m_tune(infer), y) | ||
end | ||
update!(opt, ps, grads) | ||
end | ||
end | ||
``` | ||
|
||
```julia | ||
ps = Flux.params(m_tune); | ||
opt = Adam(3e-4) | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's use the new Optmisers.jl interface instead of the params-based one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Metalhead compat forces use of Flux v0.13.4, for which my understanding was that the new Optimisers.jl wasn't yet in place.
There were some pending issues on weights import with Metalhead which once fixed, should allow to bump Flux compat, and I'd migrate the tutorial to Optimisers/explicit gradients once done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A new Metalhed version just got released
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updating to latest Flux brought some new challenges with MLUtils!
Creating the DataLoaders with parallel=true
results in stalled / hanging julia session, at least on Windows: https://github.com/jeremiedb/model-zoo/blob/51e6a091aa9d790d3162cfe8a106ad45aa0beac1/tutorials/transfer_learning/transfer_learning.jl#L35-L56
This occurs when running manually from the REPL, but works fine when launched as a script. Looks like it might be related to JuliaML/MLUtils.jl#142
Another "annoyance" is the need to have a collect
in the getobs
recipe, as mentionned in JuliaML/MLUtils.jl#139, as it seems opposite the the typical Julia pattern where explicit collect
isn't typically needed.
Would you be fine moving forward with the new Flux/MLUtils/explicit gradient approach, considering the above few caveats? I'd be fine going forward with latest versions and get these few gotchas ironed out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR has been updated to explicit gradients
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@CarloLucibello Do you feel this explicit gradient refresh is robust enough to go forward, or should first figure out JuliaML/MLUtils.jl#148 and others cavats?
v0.13 + as it does not use implicit parameters
I am not too sure what is the desired target format for the tutorials, notably as I noticed that associated tutorials like the A 60 Minute Blitz is no longer directly visible on https://fluxml.ai/.
However, I thought it would still be relevant to have blog post stlye tutorials, similar to the [DataLoader] one: https://github.com/FluxML/model-zoo/tree/master/tutorials/dataloader.
As such, I made a complete transfer learning tutorial for vision in such format, which in the meatime shows how to use custom data container and data augmentation, which seems fairly commonly asked for.
If such format is deemed of interest, I'll add a some discussion elements to enrich a bit the explanation of each of the code blocks.
I'd also suggest that I replace exisiting transfer_learning.jl and dataloader.jl by the single self contained script that allows to directly run the tutorial presented in the README.