NOTE: This use case is not intended for resource-constrained devices.
Weather Classification
Credit: AITS Cainvas Community
Photo by Sergey Galtsev on Dribbble
Image tagging helps select images based on their content, which is especially useful in search engines and similar applications. Here, we tag images based on the weather in the scene. There are two classes: cloudy and sunny.
In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from tensorflow.keras import layers, optimizers, models, preprocessing, losses, callbacks
import os
import random
from PIL import Image
import tensorflow as tf
import tensorflow.keras
In [2]:
!wget https://cainvas-static.s3.amazonaws.com/media/user_data/cainvas-admin/weather.zip
!unzip -qo weather.zip
!rm weather.zip
In [3]:
# Loading the dataset
path = 'weather/'
input_shape = (256, 256, 3) # default input shape while loading the images
batch = 64
# The train and test datasets
print("Train dataset")
train_ds = preprocessing.image_dataset_from_directory(path+'train', batch_size=batch, label_mode='binary')
print("Test dataset")
test_ds = preprocessing.image_dataset_from_directory(path+'test', batch_size=batch, label_mode='binary')
In [4]:
# How many samples in each class
for t in ['train', 'test']:
    print('\n', t.upper())
    for x in os.listdir(path + t):
        print(x, ' - ', len(os.listdir(path + t + '/' + x)))
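The train set is balanced while the test set is imbalanced. On an imbalanced test set, overall accuracy is dominated by the larger class, so a confusion matrix is useful for finding the accuracy of each class separately.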
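As a minimal sketch of how a confusion matrix exposes per-class accuracy, the cell below uses made-up labels rather than predictions from this notebook. The arrays `y_true` and `y_pred` are hypothetical, and the mapping 0 = cloudy, 1 = sunny is an assumption for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels for an imbalanced test set (assumed: 0 = cloudy, 1 = sunny).
# These values are illustrative only and are not results from this notebook.
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 1])

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Overall accuracy: correct predictions divided by all samples.
overall_accuracy = np.trace(cm) / cm.sum()

# Per-class accuracy: correct predictions for a class divided by
# the number of true samples of that class (row sums).
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)

print(cm)
print("Overall accuracy:", overall_accuracy)
for name, acc in zip(["cloudy", "sunny"], per_class_accuracy):
    print(name, "accuracy:", acc)
```

In this toy case the overall accuracy looks reasonable while the smaller class is right only half the time, which is the kind of gap that a single accuracy number can hide.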
In [5]:
# Looking into the class labels
class_names = train_ds.class_names
print("Train class names: ", train_ds.class_names)
print("Test class names: ", test_ds.class_names)
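With inferred labels, `image_dataset_from_directory` takes the class names from the subdirectory names in alphanumeric order and uses each name's position as its integer label. The sketch below mimics that ordering with plain Python so the mapping can be checked without loading any images. The temporary folder and the `label_of` dictionary are hypothetical stand-ins, and the folder names `cloudy` and `sunny` are assumed from the two classes described earlier.

```python
import os
import tempfile

# Hypothetical stand-in for a dataset folder with one subdirectory per class.
with tempfile.TemporaryDirectory() as root:
    for name in ["sunny", "cloudy"]:
        os.makedirs(os.path.join(root, name))

    # Sort the subdirectory names; the index in the sorted list is the label.
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    label_of = {name: idx for idx, name in enumerate(class_names)}

print(class_names)
print(label_of)
```

Under these assumptions, a label of 0 corresponds to the first name in `class_names` and a label of 1 to the second, which matters later when reading binary predictions.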
Visualization
In [6]:
num_samples = 4 # the number of samples to be displayed in each class
for x in class_names:
    plt.figure(figsize=(20, 20))
    filenames = os.listdir(path + 'train/' + x)
    for i in range(num_samples):
        ax = plt.subplot(1, num_samples, i + 1)
        img = Image.open(path + 'train/' + x + '/' + filenames[i])
        plt.imshow(img)
        plt.title(x)
        plt.axis("off")