
In previous posts we’ve looked into the basic structure of the torch-dataframe package. In this post we’ll go through the mnist example, which shows how to best integrate the dataframe with torchnet.
The getIterator
Everything interesting happens in the getIterator function. Its purpose is to select the train, test, or validate dataset and return an instance of tnt.DatasetIterator that yields tables with input and target tensors:

{ input = ..., target = ... }
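To make this concrete, here is a minimal sketch of how such an iterator is consumed. The getIterator name and the mode argument follow the example script; the loop itself is standard torchnet iterator usage:

local iterator = getIterator('train')
for sample in iterator() do
   -- sample.input and sample.target hold the tensors for one batch
   print(sample.input:size(), sample.target:size())
end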
The mnist dataset
The mnist dataset is loaded via the mnist package (mode is either 'train' or 'test', selecting mnist.traindataset() or mnist.testdataset()):
local mnist = require 'mnist'
local mnist_dataset = mnist[mode .. 'dataset']()
The labels from the dataset are then converted to a Dataframe:
local df = Dataframe(
   Df_Dict{
      label = mnist_dataset.label:totable(),
      row_id = torch.range(1, mnist_dataset.data:size(1)):totable()
   })
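If you want to sanity-check the result, a quick inspection works well here; a hedged sketch, assuming the shape and head helpers behave as in the torch-dataframe docs:

-- Peek at the freshly built dataframe
print(df:shape())  -- number of rows and columns, e.g. 60000 rows for 'train'
print(df:head(5))  -- the first five rows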
The image data is retrieved via an external resource, just as you would for any external data storage:
-- Since the mnist package has already taken care of the data
-- splitting we create a single subset
df:create_subsets{
   subsets = Df_Dict{core = 1},
   data_retriever = function(row)
      return ext_resource[row.row_id]
   end,
   label_retriever = Df_Array("label")
}
local subset = df["/core"]
Note that we create a single subset here since the data is already split. For datasets that aren’t pre-split, the same call can produce several subsets at once, as sketched below.
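A hedged sketch of a three-way split, following the create_subsets documentation; the proportions are made up for illustration:

-- Hypothetical split instead of the single 'core' subset
df:create_subsets{
   subsets = Df_Dict{train = 0.7, validate = 0.2, test = 0.1},
   data_retriever = function(row)
      return ext_resource[row.row_id]
   end,
   label_retriever = Df_Array("label")
}
local train_subset = df["/train"]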
The iterators
The dataframe has two specialized iterator classes for setting up the iterator: Df_Iterator and Df_ParallelIterator. The difference is that Df_ParallelIterator lets you set up multiple threads (the nthread argument) and take care of the external data loading there. Note: this differs from torchnet’s own parallel iterator, which exports the entire dataset to the threads and then does everything within each thread. The reason for our approach is that torchnet’s version won’t work with the samplers, and we believe the extra cost is negligible as long as you don’t keep all your data in the csv file.
The plain iterator
Here we set up the external resource and then create the iterator that will have access to that resource:
ext_resource = mnist_dataset.data:reshape(
   mnist_dataset.data:size(1),
   mnist_dataset.data:size(2) * mnist_dataset.data:size(3)):double()

return Df_Iterator{
   dataset = subset,
   batch_size = 128,
   target_transform = function(val)
      return val + 1
   end
}
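The val + 1 in target_transform is needed because the mnist labels run from 0 to 9 while torch’s classification criteria expect 1-based class indices. Handing the iterator to a torchnet engine then looks roughly like this; a minimal sketch that assumes a network net has been defined elsewhere:

local tnt = require 'torchnet'
local nn = require 'nn'

local engine = tnt.SGDEngine()
engine:train{
   network   = net,  -- assumed defined elsewhere
   criterion = nn.ClassNLLCriterion(),
   iterator  = getIterator('train'),
   lr        = 0.1,
   maxepoch  = 5
}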
The parallel iterator
This is similar to the above, but in addition we must load the required packages and the external resource inside each thread:
return Df_ParallelIterator{
   dataset = subset,
   batch_size = 128,
   init = function(idx)
      -- Load the libraries needed
      require 'torch'
      require 'Dataframe'

      -- Load the dataset's external resource
      local mnist = require 'mnist'
      local mnist_dataset = mnist[mode .. 'dataset']()
      ext_resource = mnist_dataset.data:reshape(
         mnist_dataset.data:size(1),
         mnist_dataset.data:size(2) * mnist_dataset.data:size(3)):double()
   end,
   nthread = 2,
   target_transform = function(val)
      return val + 1
   end
}
The reset call
The torchnet engines don’t resample the dataset after each epoch; they simply restart after iterating my_dataset:size() times. We therefore need to add a hook so that reset_sampler is invoked. This is only needed for the samplers that require resetting (linear, ordered and permutation), but it is recommended as standard practice since it makes it easier to switch between samplers. The hook belongs to the engine and is set a little further down in the script:
engine.hooks.onEndEpoch = function(state)
   print("End epoch no " .. state.epoch)
   state.iterator.dataset:reset_sampler()
end
A little about torchnet
As I’ve been adapting torch-dataframe to torchnet I’ve learned to appreciate its brilliant structure. The dataset layers let you build up complexity as needed, and the possibilities are endless. The engine is elegant, and understanding the hooks is trivial if you just look at the code (from the SGDEngine):
self.hooks("onStart", state)
while state.epoch < state.maxepoch do
   state.network:training()

   self.hooks("onStartEpoch", state)
   for sample in state.iterator() do
      state.sample = sample
      self.hooks("onSample", state)

      state.network:forward(sample.input)
      self.hooks("onForward", state)
      state.criterion:forward(state.network.output, sample.target)
      self.hooks("onForwardCriterion", state)

      state.network:zeroGradParameters()
      if state.criterion.zeroGradParameters then
         state.criterion:zeroGradParameters()
      end

      state.criterion:backward(state.network.output, sample.target)
      self.hooks("onBackwardCriterion", state)
      state.network:backward(sample.input, state.criterion.gradInput)
      self.hooks("onBackward", state)

      assert(state.lrcriterion >= 0, 'lrcriterion should be positive or zero')
      if state.lrcriterion > 0 and state.criterion.updateParameters then
         state.criterion:updateParameters(state.lrcriterion)
      end
      assert(state.lr >= 0, 'lr should be positive or zero')
      if state.lr > 0 then
         state.network:updateParameters(state.lr)
      end
      state.t = state.t + 1
      self.hooks("onUpdate", state)
   end
   state.epoch = state.epoch + 1
   self.hooks("onEndEpoch", state)
end
self.hooks("onEnd", state)
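This structure makes it easy to attach your own bookkeeping. For instance, tracking the average loss only takes a couple of hooks; a minimal sketch using torchnet’s AverageValueMeter together with the reset_sampler call from above:

local tnt = require 'torchnet'
local meter = tnt.AverageValueMeter()

-- Record the criterion output after every forward pass
engine.hooks.onForwardCriterion = function(state)
   meter:add(state.criterion.output)
end

-- Report and reset at the end of each epoch
engine.hooks.onEndEpoch = function(state)
   local avg_loss = meter:value()
   print(("End epoch no %d - avg. loss %.4f"):format(state.epoch, avg_loss))
   meter:reset()
   state.iterator.dataset:reset_sampler()
end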
I believe that torch-dataframe has its place in the infrastructure, as it allows you to better visualize your data. It also brings some basic data operations to the table, where the simple as_categorical will quickly allow you to make sense of the network’s outputs, as sketched below.
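A hedged sketch of that workflow, assuming the as_categorical and to_categorical signatures from the torch-dataframe docs; the predicted index is made up for illustration:

-- Mark the label column as categorical
df:as_categorical('label')

-- Map a predicted class index back to its human-readable label
local predicted_class = 3  -- hypothetical argmax of the network output
print(df:to_categorical(predicted_class, 'label'))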
Summary
In this post we’ve looked closer at the mnist example and the components it uses. Hopefully you can use it as a template in your own research.