In this post, we are going to build an image classifier. Instead of designing and training our own neural network, we will use MobileNet, a convolutional neural network that has already been trained on the ImageNet dataset and has millions of parameters spread across many layers.
So we just have to download the pretrained model, which takes a single line in the code below, instead of building a complex neural network and training it for many epochs. Pretrained models like this are published online, and you can also use others, such as VGG16 and many more.
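Swapping in another pretrained model mostly means pairing it with its own preprocessing and label-decoding functions, because each model expects its input prepared differently. The sketch below assumes TensorFlow 2.x with Keras Applications; `PRETRAINED_MODELS` and `load_pretrained` are hypothetical names chosen for illustration, and the input sizes are the Keras defaults for each model.

```python
import tensorflow as tf

# Hypothetical registry for this sketch: model constructor, the module holding
# its matching preprocess_input/decode_predictions, and its default input size.
PRETRAINED_MODELS = {
    "mobilenet": (tf.keras.applications.MobileNet,
                  tf.keras.applications.mobilenet, 224),
    "vgg16": (tf.keras.applications.VGG16,
              tf.keras.applications.vgg16, 224),
    "resnet50": (tf.keras.applications.ResNet50,
                 tf.keras.applications.resnet50, 224),
    "inception_v3": (tf.keras.applications.InceptionV3,
                     tf.keras.applications.inception_v3, 299),
}


def load_pretrained(name, weights="imagenet"):
    """Return (model, preprocess_input, decode_predictions, input_size).

    weights="imagenet" downloads the pretrained weights on first use;
    weights=None builds the same architecture with random weights.
    """
    constructor, module, size = PRETRAINED_MODELS[name]
    model = constructor(weights=weights)
    return model, module.preprocess_input, module.decode_predictions, size
```

For example, `load_pretrained("vgg16")` fetches the VGG16 ImageNet weights the first time it runs, while `weights=None` is handy for checking shapes without a download.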
In this code, I am going to predict a dog's breed with MobileNet. Its 1,000 ImageNet classes include many dog breeds, so it can name a breed directly. You can download any picture from the internet and see what the model predicts. The model is not perfect, so its predictions will not always be accurate.
The model returns the five most likely classes for your input image along with their scores, and it will typically be more accurate than a small neural network you train from scratch.
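Keras' `decode_predictions` handles the bookkeeping for the 1,000 ImageNet classes. The underlying idea, picking the highest-scoring classes from the model's output, fits in a few lines of NumPy. This is a minimal sketch of that idea, not Keras' own implementation; `top_k_predictions` and the toy labels are hypothetical.

```python
import numpy as np


def top_k_predictions(probs, labels, k=5):
    """Return the k most probable (label, probability) pairs, highest first.

    probs:  1-D array of class scores, e.g. one row of model.predict(...)
    labels: sequence of class names, same length as probs
    """
    probs = np.asarray(probs)
    # argsort sorts ascending, so take the last k indices and reverse them
    top_idx = np.argsort(probs)[-k:][::-1]
    return [(labels[i], float(probs[i])) for i in top_idx]


# Toy example with 5 made-up classes (not real model output)
toy_probs = [0.05, 0.60, 0.10, 0.20, 0.05]
toy_labels = ["beagle", "pug", "tabby", "basset", "goldfish"]
print(top_k_predictions(toy_probs, toy_labels, k=3))
# -> [('pug', 0.6), ('basset', 0.2), ('tabby', 0.1)]
```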
Preview:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import imagenet_utils
from tensorflow.keras.preprocessing import image
# download the MobileNet model with ImageNet weights (fetched on the first run, then cached)
mobile = tf.keras.applications.mobilenet.MobileNet()
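# preprocess an image to the size and format MobileNet expects
def prepare_image(file, img_path="C:/data/cats_vs_dogs/test/dog/"):
    # img_path is the directory holding your test images; change it to your own
    img = image.load_img(img_path + file, target_size=(224, 224))  # load and resize to 224x224
    img_array = image.img_to_array(img)  # image -> float array of shape (224, 224, 3)
    img_array_expand_dims = np.expand_dims(img_array, axis=0)  # add batch axis -> (1, 224, 224, 3)
    # scale pixel values from [0, 255] to [-1, 1], the range MobileNet was trained on
    return keras.applications.mobilenet.preprocess_input(img_array_expand_dims)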
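preprocessed_image = prepare_image("126.jpg")  # enter the name of any image you want to predict
predictions = mobile.predict(preprocessed_image)  # shape (1, 1000): one score per ImageNet class
results = imagenet_utils.decode_predictions(predictions)  # top 5 (class_id, label, score) per image
for class_id, label, score in results[0]:
    print(f"{label}: {score:.2%}")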
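The `prepare_image` function hands the model a batch containing one 224x224 RGB image whose pixels `preprocess_input` has rescaled from [0, 255] to [-1, 1]. The NumPy sketch below spells out those last two steps so the shapes and value range are visible. `mobilenet_style_preprocess` is a hypothetical helper, and it assumes the image has already been loaded and resized to 224x224.

```python
import numpy as np


def mobilenet_style_preprocess(img_array):
    """Mimic the last steps of prepare_image on an already-resized image.

    img_array: (224, 224, 3) array of RGB pixel values in [0, 255]
    Returns a (1, 224, 224, 3) float32 batch scaled to [-1, 1].
    """
    x = np.asarray(img_array, dtype=np.float32)
    # Add a batch axis: the model expects (batch, height, width, channels)
    x = np.expand_dims(x, axis=0)
    # Scale pixels from [0, 255] to [-1, 1], as MobileNet's preprocess_input does
    return x / 127.5 - 1.0
```

Other models expect different preprocessing (VGG16 and ResNet50, for instance, subtract per-channel ImageNet means instead), so use the `preprocess_input` that matches the model you load.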