Training
GQ-CNN training classes are never instantiated directly. Instead, a lightweight factory function returns the trainer class for the specified backend, and that class is then constructed with the trainer's initializer arguments.
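from gqcnn import get_gqcnn_trainer

backend = 'tf'
# The factory returns the trainer class; calling it constructs a trainer
# using the GQCNNTrainerTF initializer arguments documented below.
my_trainer = get_gqcnn_trainer(backend)(gqcnn, dataset_dir, split_name, output_dir, config)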
gqcnn.get_gqcnn_trainer(backend='tf')
Get the GQ-CNN trainer for the provided backend.
Note
Currently only TensorFlow is supported.
Parameters
backend (str) – The backend to use; currently only 'tf' is supported.
Returns
The GQ-CNN trainer class for the TensorFlow backend (call it to construct a trainer).
Return type
GQCNNTrainerTF (the class itself, not an instance)
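To make the factory pattern concrete, here is a minimal sketch of how a backend-dispatching factory of this kind can work. It is not the gqcnn source: `StubTrainerTF`, `_TRAINERS`, and `get_trainer` are hypothetical stand-ins, and raising `ValueError` for an unknown backend is an assumption. The point it shows is that the factory returns a class, which the caller then instantiates.

```python
# Hypothetical sketch of a backend-dispatching factory; NOT the gqcnn implementation.


class StubTrainerTF:
    """Stand-in for GQCNNTrainerTF that only records its constructor arguments."""

    def __init__(self, gqcnn, dataset_dir, split_name, output_dir, config,
                 name=None, progress_dict=None, verbose=True):
        self.gqcnn = gqcnn
        self.dataset_dir = dataset_dir
        self.split_name = split_name
        self.output_dir = output_dir
        self.config = config
        self.name = name


# Map each supported backend string to its trainer class (assumed design).
_TRAINERS = {"tf": StubTrainerTF}


def get_trainer(backend="tf"):
    """Return the trainer *class* (not an instance) registered for `backend`."""
    try:
        return _TRAINERS[backend]
    except KeyError:
        # Assumption: unsupported backends are rejected with a ValueError.
        raise ValueError("Invalid backend: {}".format(backend)) from None


# Look up the class, then instantiate it (paths and split name are placeholders).
trainer = get_trainer("tf")(None, "/path/to/dataset", "my_split",
                            "/path/to/output", {})
```

Keeping the lookup separate from construction means each backend's trainer can define its own initializer while callers share one entry point.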
class gqcnn.training.tf.GQCNNTrainerTF(gqcnn, dataset_dir, split_name, output_dir, config, name=None, progress_dict=None, verbose=True)
Bases: object
Trains a GQ-CNN with the TensorFlow backend.
__init__(gqcnn, dataset_dir, split_name, output_dir, config, name=None, progress_dict=None, verbose=True)
Parameters
gqcnn (GQCNN) – Grasp quality neural network to optimize.
dataset_dir (str) – Path to the training/validation dataset.
split_name (str) – Name of the split to train on.
output_dir (str) – Path to save the model output.
config (dict) – Dictionary of configuration parameters.
name (str) – Name of the model.
train()
Perform optimization.
finetune(base_model_dir)
Perform fine-tuning.
Parameters
base_model_dir (str) – Path to the pre-trained base model to use.
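A trainer exposes two entry points: `train()` optimizes the network, while `finetune(base_model_dir)` starts from a pre-trained base model. The sketch below shows one way a calling script might choose between them. `StubTrainer` and `run_training` are hypothetical names used only for illustration; a real script would pass a `GQCNNTrainerTF` instance instead of the stub.

```python
# Hypothetical helper: choose between train() and finetune(base_model_dir).
from typing import Optional


class StubTrainer:
    """Stand-in exposing the two documented methods; records which one ran."""

    def __init__(self):
        self.calls = []

    def train(self):
        self.calls.append(("train",))

    def finetune(self, base_model_dir):
        self.calls.append(("finetune", base_model_dir))


def run_training(trainer, base_model_dir: Optional[str] = None) -> str:
    """Fine-tune from base_model_dir if one is given, otherwise call train()."""
    if base_model_dir is not None:
        trainer.finetune(base_model_dir)
        return "finetune"
    trainer.train()
    return "train"


# Usage: one trainer optimized directly, one fine-tuned from a placeholder path.
scratch = StubTrainer()
mode_a = run_training(scratch)
tuned = StubTrainer()
mode_b = run_training(tuned, "/path/to/pretrained_model")
```

Making the base model path optional keeps a single code path for both workflows; whether that suits a given training script is a judgment call.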