
Integration between torchnet and torch-dataframe – a closer look at the mnist example

It's all about the numbers and getting the tensors right. The image is cc by David Asch.

In previous posts we’ve looked into the basic structure of the torch-dataframe package. In this post we’ll go through the mnist example that shows how to best integrate the dataframe with torchnet.

All posts in the torch-dataframe series

  1. Intro to the torch-dataframe
  2. Modifications
  3. Subsetting
  4. The mnist example
  5. Multilabel classification

The getIterator

Everything interesting is located in the getIterator function. This function's purpose is to select the train, test or validate dataset and return an instance of tnt.DatasetIterator that yields a table with input and target tensors:

{
  input = ...,
  target = ...
}
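
As a rough sketch of how such an iterator is consumed (assuming getIterator takes the mode as its argument, and using the batch size of 128 and the flattened 28 x 28 images that appear further down):

local iterator = getIterator('train')
for sample in iterator() do
  -- sample.input is e.g. a 128 x 784 DoubleTensor of flattened images
  -- sample.target holds the corresponding class indices
  print(sample.input:size(), sample.target:size())
end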

The mnist dataset

The mnist dataset is loaded via the mnist package:

local mnist = require 'mnist'
-- `mode` is either 'train' or 'test', matching the mnist package's
-- traindataset()/testdataset() functions
local mnist_dataset = mnist[mode .. 'dataset']()

The labels from the dataset are then converted to a Dataframe:

local df = Dataframe(
  Df_Dict{
    label = mnist_dataset.label:totable(),
    row_id = torch.range(1, mnist_dataset.data:size(1)):totable()
  })

The image data is retrieved using an external resource just as you would for any external data storage:

-- Since the mnist package has already taken care of the data
-- splitting, we create a single subset
df:create_subsets{
  subsets = Df_Dict{core = 1},
  data_retriever = function(row)
    return ext_resource[row.row_id]
  end,
  label_retriever = Df_Array("label")
}
local subset = df["/core"]

Note that we create a single subset here since the data has already been split by the mnist package.
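
Had the data not been pre-split, the same call could instead carve out several subsets by giving proportions in the Df_Dict; a hedged sketch with arbitrary proportions:

-- Assumed usage: split the rows into three subsets by proportion
-- instead of the single 'core' subset used above
df:create_subsets{
  subsets = Df_Dict{train = 0.7, validate = 0.2, test = 0.1},
  data_retriever = function(row)
    return ext_resource[row.row_id]
  end,
  label_retriever = Df_Array("label")
}
local train_subset = df["/train"]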

The iterators

The dataframe has two specialized iterator classes, Df_Iterator and Df_ParallelIterator. The difference is that Df_ParallelIterator lets you set up multiple threads (the nthread argument) and do the external data loading there. Note that this differs from torchnet's own parallel iterator, which exports the entire dataset to the threads and then does everything within each thread. The reason for our approach is that exporting the dataset won't work with the samplers, and we believe the extra cost is negligible as long as you don't keep all your data in the csv file.

The plain iterator

Here we set up the external resource and then create the iterator that will have access to that resource:

-- Flatten the 28 x 28 images into rows of 784 values
ext_resource = mnist_dataset.data:reshape(mnist_dataset.data:size(1),
  mnist_dataset.data:size(2) * mnist_dataset.data:size(3)):double()

return Df_Iterator{
  dataset = subset,
  batch_size = 128,
  target_transform = function(val)
    -- mnist labels are 0-9 while torch class indices are 1-based
    return val + 1
  end
}

The parallel iterator

This is similar to the above, but in addition we must load the required packages and the external resource inside each thread:

return Df_ParallelIterator{
  dataset = subset,
  batch_size = 128,
  init = function(idx)
    -- Load the libraries needed
    require 'torch'
    require 'Dataframe'

    -- Load the datasets external resource
    local mnist = require 'mnist'
    local mnist_dataset = mnist[mode .. 'dataset']()
    ext_resource = mnist_dataset.data:reshape(mnist_dataset.data:size(1),
      mnist_dataset.data:size(2) * mnist_dataset.data:size(3)):double()
  end,
  nthread = 2,
  target_transform =  function(val)
    return val + 1
  end
}

The reset call

The torchnet engines don't resample the dataset after each epoch; they simply restart after iterating my_dataset:size() times. We therefore need to add a hook so that reset_sampler is invoked. This is only needed for the samplers that require resetting (linear, ordered and permutation), but it is recommended as standard practice since it makes switching between samplers easier. The hook belongs to the engine and is set a little further down the script:

engine.hooks.onEndEpoch = function(state)
  print("End epoch no " .. state.epoch)
  state.iterator.dataset:reset_sampler()
end
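
For context, a minimal sketch of how the hook sits in the training call; the network and criterion below are stand-ins rather than the model from the actual mnist script:

local tnt = require 'torchnet'
local nn = require 'nn'

local engine = tnt.SGDEngine()

-- Reset the sampler so the next epoch starts from the beginning again
engine.hooks.onEndEpoch = function(state)
  print("End epoch no " .. state.epoch)
  state.iterator.dataset:reset_sampler()
end

-- Stand-in model; the real example defines its own network
local net = nn.Sequential()
  :add(nn.Linear(784, 10))
  :add(nn.LogSoftMax())

engine:train{
  network   = net,
  criterion = nn.ClassNLLCriterion(),
  iterator  = getIterator('train'),
  lr        = 0.1,
  maxepoch  = 5
}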

A little about torchnet

As I've been adapting torch-dataframe to torchnet I've learned to appreciate its brilliant structure. The dataset layers let you build up complexity as needed and the possibilities are endless. The engine is elegant, and understanding the hooks is trivial if you just look at the code (from the SGDEngine):

self.hooks("onStart", state)
while state.epoch < state.maxepoch do
   state.network:training()

   self.hooks("onStartEpoch", state)
   for sample in state.iterator() do
      state.sample = sample
      self.hooks("onSample", state)

      state.network:forward(sample.input)
      self.hooks("onForward", state)
      state.criterion:forward(state.network.output, sample.target)
      self.hooks("onForwardCriterion", state)

      state.network:zeroGradParameters()
      if state.criterion.zeroGradParameters then
         state.criterion:zeroGradParameters()
      end

      state.criterion:backward(state.network.output, sample.target)
      self.hooks("onBackwardCriterion", state)
      state.network:backward(sample.input, state.criterion.gradInput)
      self.hooks("onBackward", state)

      assert(state.lrcriterion >= 0, 'lrcriterion should be positive or zero')
      if state.lrcriterion > 0 and state.criterion.updateParameters then
         state.criterion:updateParameters(state.lrcriterion)
      end
      assert(state.lr >= 0, 'lr should be positive or zero')
      if state.lr > 0 then
         state.network:updateParameters(state.lr)
      end
      state.t = state.t + 1
      self.hooks("onUpdate", state)
   end
   state.epoch = state.epoch + 1
   self.hooks("onEndEpoch", state)
end
self.hooks("onEnd", state)

I believe that torch-dataframe has its place in this infrastructure as it allows you to better visualize your data. It also brings some basic data operations to the table; the simple as_categorical, for example, quickly lets you make sense of the network's outputs.
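
A rough sketch of the idea; the exact categorical calls may differ between torch-dataframe versions, so treat the names below as assumptions rather than a reference:

-- Assumed API: mark the label column as categorical so that the
-- integer class indices the network predicts can be translated
-- back into the original label values
df:as_categorical('label')
-- to_categorical is assumed here to map numeric keys back to labels
print(df:to_categorical(Df_Array(1, 2, 3), 'label'))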

Summary

In this post we’ve looked closer at the mnist example and what components it uses. Hopefully you can use this as a template in your own research.
