Predictions

The package offers several ways to obtain predictions on your test dataset. Only datasets in TFRecord format are supported.

Supported methods for obtaining predictions:

  • get_predictions

  • get_predictions_tfrec

  • ensemble_predictions

    • Model Averaging

    • Model Weighted

    • Test Time Augmentations (Coming Soon)

Get Predictions

Obtains predictions on a test dataset that was loaded with get_test_dataset() from begin_tpu.

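Import the function from quick_ml.predictions. The import statements and an example call for each method on this page are listed together at the end of the page.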

Call get_predictions(testTFdataset, model, output_filename) once you have loaded the test dataset and the model.
This is usually the best option when you have a single model and want its predictions in the same session, which is why the test dataset must already be loaded.

Get Predictions TFRec

get_predictions_tfrec(GCS_DS_PATH, test_tfrec_path, BATCH_SIZE, model, output_filename) is useful when the model weights were trained in one session and you want to obtain predictions in another. It is also convenient when you need predictions from several models.

These are recommendations rather than strict rules; either function can be used in other situations as well.

Ensemble Predictions

ensemble_predictions combines the predictions of several models using either the 'Model Averaging' or the 'Model Weighted' method.
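The difference between the two ensemble types can be illustrated with plain NumPy. This is a minimal sketch of the general techniques, not quick_ml's implementation. It assumes each model outputs an array of class probabilities with the same shape, and that the weighted variant uses one non-negative coefficient per model; the function names and the model_weights parameter are hypothetical.

```python
import numpy as np


def average_ensemble(prob_list):
    """Model Averaging (sketch): unweighted mean of the per-model probabilities."""
    stacked = np.stack(prob_list)  # shape: (n_models, n_samples, n_classes)
    return stacked.mean(axis=0)


def weighted_ensemble(prob_list, model_weights):
    """Model Weighted (sketch): weighted mean; np.average normalises the weights."""
    stacked = np.stack(prob_list)
    return np.average(stacked, axis=0, weights=model_weights)


# Hypothetical probabilities from two models on three samples with two classes
p1 = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8]])
p2 = np.array([[0.7, 0.3], [0.6, 0.4], [0.4, 0.6]])

avg = average_ensemble([p1, p2])                # ≈ [[0.8, 0.2], [0.5, 0.5], [0.3, 0.7]]
weighted = weighted_ensemble([p1, p2], [3, 1])  # ≈ [[0.85, 0.15], [0.45, 0.55], [0.25, 0.75]]
```

With weights [3, 1], the first model contributes three quarters of each averaged probability, so the result leans toward its predictions; with equal weights the two variants coincide.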

Arguments:

    models_list - list of model instances whose predictions are ensembled.

    testTFdataset - the loaded, unlabeled test dataset obtained from TFRecords.

    ensemble_type - 'Model Averaging' (default) or 'Model Weighted'.

    classification_type - 'binary' (default) or 'multiclass'.

    weights - pretrained weights, if any. Default: None.

Returns:

   Saves the predictions to a CSV file, produced according to the chosen ensemble_type.
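Before predictions can be written to a CSV file, the probabilities usually have to be turned into labels, and the rule depends on classification_type. The following is a minimal sketch of one common approach, not quick_ml's actual output format. It assumes binary models emit a single sigmoid probability per sample (thresholded at 0.5) and multiclass models emit one probability per class; the helper names, the id/label columns and the file name predictions.csv are hypothetical.

```python
import csv

import numpy as np


def probabilities_to_labels(probs, classification_type="binary", threshold=0.5):
    """Turn predicted probabilities into class labels (sketch)."""
    probs = np.asarray(probs)
    if classification_type == "binary":
        # Assumes one sigmoid output per sample: label 1 if p >= threshold, else 0.
        return (probs.reshape(-1) >= threshold).astype(int)
    if classification_type == "multiclass":
        # Assumes one probability per class: pick the most likely class index.
        return probs.argmax(axis=1)
    raise ValueError("classification_type must be 'binary' or 'multiclass'")


def save_predictions_csv(ids, labels, output_filename):
    """Write one row per test sample; the 'id,label' header is hypothetical."""
    with open(output_filename, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "label"])
        for sample_id, label in zip(ids, labels):
            writer.writerow([sample_id, int(label)])


binary_labels = probabilities_to_labels([[0.2], [0.7], [0.5]], "binary")
multi_labels = probabilities_to_labels([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], "multiclass")
save_predictions_csv([0, 1, 2], binary_labels, "predictions.csv")
```

Thresholding and argmax are the usual choices for sigmoid and softmax outputs respectively; a different threshold can be passed if the positive class should be predicted more or less readily.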

     

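# Import the prediction functions
from quick_ml.predictions import get_predictions
from quick_ml.predictions import get_predictions_tfrec
from quick_ml.predictions import ensemble_predictions

# Get Predictions: test dataset and model already loaded in the current session
predictions = get_predictions(testTFdataset, model, output_filename)

# Get Predictions TFRec: model weights trained in a different session
predictions = get_predictions_tfrec(GCS_DS_PATH, test_tfrec_path, BATCH_SIZE, model, output_filename)

# Ensemble Predictions: combine several models ('Model Averaging' or 'Model Weighted')
predictions = ensemble_predictions(models_list, testTFdataset, ensemble_type = 'Model Averaging', classification_type = 'multiclass', weights = None)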