Training

GQ-CNN training classes are not accessed directly. Instead, a lightweight factory function returns the trainer class that corresponds to the specified backend.

from gqcnn import get_gqcnn_trainer

backend = 'tf'
my_trainer = get_gqcnn_trainer(backend)(<class initializer args>)

gqcnn.get_gqcnn_trainer(backend='tf')

Get the GQ-CNN Trainer for the provided backend.

Note

Currently only TensorFlow is supported.

Parameters

backend (str) – The backend to use. Currently only “tf” is supported.

Returns

GQ-CNN Trainer class for the TensorFlow backend.

Return type

gqcnn.training.tf.GQCNNTrainerTF
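
The dispatch-on-backend pattern described above can be sketched as follows. This is a minimal illustration of the pattern, not the library's actual source: the stub class, the `_TRAINERS` registry, and the name `get_trainer_sketch` are all hypothetical stand-ins so the snippet runs without gqcnn installed.

```python
class GQCNNTrainerTF:
    """Hypothetical stub standing in for gqcnn.training.tf.GQCNNTrainerTF."""


# Hypothetical registry mapping backend names to trainer classes.
_TRAINERS = {"tf": GQCNNTrainerTF}


def get_trainer_sketch(backend="tf"):
    """Return the trainer class (not an instance) registered for `backend`."""
    try:
        return _TRAINERS[backend]
    except KeyError:
        # Assumption: an unsupported backend is reported as a ValueError.
        raise ValueError("Unsupported backend: {!r}".format(backend))
```

Because the factory returns a class, the caller still has to invoke it with the initializer arguments to obtain a trainer object.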

GQCNNTrainerTF

class gqcnn.training.tf.GQCNNTrainerTF(gqcnn, dataset_dir, split_name, output_dir, config, name=None, progress_dict=None, verbose=True)

Bases: object

Trains a GQ-CNN with the TensorFlow backend.

__init__(gqcnn, dataset_dir, split_name, output_dir, config, name=None, progress_dict=None, verbose=True)
Parameters
  • gqcnn (GQCNN) – Grasp quality neural network to optimize.

  • dataset_dir (str) – Path to the training/validation dataset.

  • split_name (str) – Name of the split to train on.

  • output_dir (str) – Path to save the model output.

  • config (dict) – Dictionary of configuration parameters.

  • name (str) – Name of the model.

train()

Perform optimization.
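
Putting the factory, the initializer, and train() together, a training run might look like the sketch below. The helper name `train_from_scratch` is hypothetical, and the arguments are passed through unchanged in the order given by the signature above. Building the GQCNN object and the config dictionary is outside the scope of this sketch.

```python
def train_from_scratch(gqcnn, dataset_dir, split_name, output_dir, config,
                       backend="tf", name=None):
    """Build a trainer through the factory and run train() on it."""
    # Imported lazily so this sketch can be loaded without gqcnn installed.
    from gqcnn import get_gqcnn_trainer

    # The factory returns a class; instantiate it with the documented arguments.
    trainer_cls = get_gqcnn_trainer(backend)
    trainer = trainer_cls(gqcnn, dataset_dir, split_name, output_dir, config,
                          name=name)
    trainer.train()
    return trainer
```

Returning the trainer lets the caller inspect it after optimization finishes. The optional progress_dict and verbose arguments are left at their defaults here.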

finetune(base_model_dir)

Perform fine-tuning.

Parameters

base_model_dir (str) – Path to the pre-trained base model to use.
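
Since finetune() depends on a pre-trained model on disk, it can help to check the path before starting. The wrapper below is a hypothetical convenience, not part of the library: it assumes only that `trainer` exposes the finetune(base_model_dir) method documented above, and the directory check is an added precaution rather than documented behavior.

```python
import os


def finetune_from(trainer, base_model_dir):
    """Check that the base model directory exists, then fine-tune."""
    # Assumption: the base model is stored as a directory on the local disk.
    if not os.path.isdir(base_model_dir):
        raise FileNotFoundError(
            "Base model directory not found: {}".format(base_model_dir))
    trainer.finetune(base_model_dir)
    return trainer
```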