
Predictions
The package provides several ways to obtain predictions on your test dataset. Only datasets in TFRecord format are supported.
Supported methods for obtaining predictions:
- get_predictions
- get_predictions_tfrec
- ensemble_predictions
  - Model Averaging
  - Model Weighted
- Test Time Augmentations (Coming Soon)
Get Predictions
Obtains predictions on the test dataset loaded with get_test_dataset() from begin_tpu.
Import get_predictions from quick_ml.predictions and call it once you have obtained both the test dataset and the model.
This is usually the best option when you have a single model and want its predictions in the same session. For this reason, the test dataset must already be loaded.
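To make the same-session workflow concrete, here is a minimal sketch of a predict-and-save step. It is not quick_ml's implementation. It assumes the loaded test dataset yields (inputs, ids) batches, uses a plain callable `predict_fn` as a stand-in for the model, and writes a hypothetical two-column (id, prediction) CSV to `output_filename`. The name `predict_to_csv` is illustrative only.

```python
import csv
import os
import tempfile
from typing import Callable, Iterable, List, Sequence, Tuple


def predict_to_csv(
    batches: Iterable[Tuple[Sequence, Sequence[str]]],
    predict_fn: Callable[[Sequence], List[float]],
    output_filename: str,
) -> List[Tuple[str, float]]:
    """Predict every (inputs, ids) batch and save the (id, prediction) rows to CSV.

    Hypothetical helper: the batch layout and CSV columns are assumptions.
    """
    rows: List[Tuple[str, float]] = []
    for inputs, ids in batches:
        scores = predict_fn(inputs)  # one score per input in this batch
        rows.extend(zip(ids, scores))  # keep each id next to its score
    with open(output_filename, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "prediction"])  # hypothetical column names
        writer.writerows(rows)
    return rows


# Usage: a toy "model" that scores each input as input / 10.
batches = [([1, 2], ["a", "b"]), ([3], ["c"])]
out_path = os.path.join(tempfile.gettempdir(), "predictions_sketch.csv")
rows = predict_to_csv(batches, lambda xs: [x / 10 for x in xs], out_path)
```

Because the dataset is already in memory, the only work left is iterating over it once and pairing each score with its id.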
Get Predictions TFRec
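This is usually helpful when the model weights were trained in one session and you want to obtain predictions in a different session. It is particularly convenient when predictions are needed from multiple models. Import get_predictions_tfrec from quick_ml.predictions; unlike get_predictions, it takes GCS_DS_PATH, test_tfrec_path and BATCH_SIZE instead of an already-loaded dataset.
However, you are free to use either function in other scenarios as well.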
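When predictions are needed from several models (for example, models whose weights were saved in earlier sessions), one convenient pattern is to run each model over the same test inputs in fixed-size batches and keep one prediction column per model. The sketch below shows that pattern with plain callables standing in for restored models. The helper names `batched` and `predict_with_models`, and the dict-of-columns output, are illustrative assumptions, not the internals of get_predictions_tfrec.

```python
from typing import Callable, Dict, List, Sequence


def batched(items: Sequence, batch_size: int) -> List[Sequence]:
    """Split items into consecutive batches of at most batch_size elements."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def predict_with_models(
    models: Dict[str, Callable[[Sequence], List[float]]],
    inputs: Sequence,
    batch_size: int,
) -> Dict[str, List[float]]:
    """Return one list of predictions per model, in the same order as inputs."""
    columns: Dict[str, List[float]] = {}
    for name, predict_fn in models.items():
        preds: List[float] = []
        for batch in batched(inputs, batch_size):
            preds.extend(predict_fn(batch))  # predict one batch at a time
        columns[name] = preds
    return columns


# Usage: two toy "models" evaluated with a batch size of 2.
models = {
    "model_a": lambda xs: [x * 2 for x in xs],
    "model_b": lambda xs: [x + 1 for x in xs],
}
columns = predict_with_models(models, [1, 2, 3], batch_size=2)
```

Keeping the outputs aligned by input order is what later makes it possible to combine the columns into an ensemble.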
Ensemble Predictions
ensemble_predictions (imported from quick_ml.predictions) combines the predictions of several models using either the 'Model Averaging' or the 'Model Weighted' method.
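The two ensemble types differ in how the models' outputs are combined. Below is a minimal sketch, assuming 'Model Averaging' takes the plain mean of each model's predicted probabilities and 'Model Weighted' takes a weighted mean with one non-negative weight per model, normalized by the sum of the weights. quick_ml's exact formula is not documented on this page, and the helper name `combine_predictions` is hypothetical.

```python
from typing import List, Optional, Sequence


def combine_predictions(
    model_preds: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Combine per-model prediction lists element-wise.

    With weights=None every model counts equally (simple averaging);
    otherwise each model's predictions are scaled by its weight and the
    result is divided by the sum of the weights (weighted averaging).
    """
    if weights is None:
        weights = [1.0] * len(model_preds)
    if len(weights) != len(model_preds):
        raise ValueError("need exactly one weight per model")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive number")
    return [
        sum(w * preds[i] for w, preds in zip(weights, model_preds)) / total
        for i in range(len(model_preds[0]))
    ]


# Probabilities from two models for the same two test samples.
preds = [[0.75, 0.25], [0.25, 0.75]]
averaged = combine_predictions(preds)                     # [0.5, 0.5]
weighted = combine_predictions(preds, weights=[3.0, 1.0])  # [0.625, 0.375]
```

Weighting lets a model you trust more pull the ensemble toward its own predictions; with equal weights the two methods coincide.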
Arguments:
models_list - list of model instances whose predictions are ensembled.
testTFdataset - the loaded, unlabeled test dataset obtained from TFRecords.
ensemble_type - Default 'Model Averaging'. Options: 'Model Averaging' or 'Model Weighted'.
classification_type - Default 'binary'. Options: 'binary' or 'multiclass'.
weights - Default None. Pretrained weights, if any.
Returns:
Saves the ensembled predictions, computed with the chosen ensemble_type, to a CSV file.
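The classification_type argument presumably determines how ensembled scores are turned into class labels. The sketch below assumes a common convention rather than quick_ml's source: 'binary' thresholds a single positive-class probability (the `threshold` parameter is a hypothetical addition, defaulting to 0.5), and 'multiclass' picks the highest-scoring class (argmax). The helper name `to_labels` is illustrative.

```python
from typing import List, Sequence, Union


def to_labels(
    scores: Sequence[Union[float, Sequence[float]]],
    classification_type: str = "binary",
    threshold: float = 0.5,
) -> List[int]:
    """Convert ensembled scores to integer class labels (hypothetical helper)."""
    if classification_type == "binary":
        # Each score is the probability of the positive class.
        return [int(s >= threshold) for s in scores]
    if classification_type == "multiclass":
        # Each score is a list of per-class probabilities; take the argmax.
        return [max(range(len(s)), key=lambda k: s[k]) for s in scores]
    raise ValueError("classification_type must be 'binary' or 'multiclass'")


# Usage
binary_labels = to_labels([0.625, 0.375])  # [1, 0]
multiclass_labels = to_labels([[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]], "multiclass")  # [1, 0]
```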
from quick_ml.predictions import get_predictions
from quick_ml.predictions import get_predictions_tfrec
from quick_ml.predictions import ensemble_predictions

# Get Predictions: the test dataset and the model are already loaded in this session
predictions = get_predictions(testTFdataset, model, output_filename)

# Get Predictions TFRec: pass the TFRecord location and batch size instead of a loaded dataset
predictions = get_predictions_tfrec(GCS_DS_PATH, test_tfrec_path, BATCH_SIZE, model, output_filename)

# Ensemble Predictions
predictions = ensemble_predictions(models_list, testTFdataset,
                                   ensemble_type='Model Averaging',
                                   classification_type='multiclass',
                                   weights=None)