📜  CIFAR-10 and CIFAR-100 Datasets in TensorFlow

📅  Last modified: 2021-01-11 10:44:04             🧑  Author: Mango


CIFAR-10 (Canadian Institute For Advanced Research) and CIFAR-100 are labeled subsets of the 80 Million Tiny Images dataset. They were collected by Alex Krizhevsky, Geoffrey Hinton, and Vinod Nair. The dataset is divided into five training batches and one test batch, each with 10,000 images.

The test batch contains exactly 1,000 randomly selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5,000 images from each class.
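Each batch ships as a Python pickle file in the official "CIFAR-10 python version" archive. As a minimal sketch (assuming the archive has been extracted to ./cifar-10-batches-py/, a path separate from the scripts later in this article), one batch can be inspected like this:

import pickle

def load_batch(path):
    # CIFAR-10 batch files are pickled dicts whose keys are byte strings.
    with open(path, "rb") as f:
        batch = pickle.load(f, encoding="bytes")
    data = batch[b"data"]      # uint8 array of shape (10000, 3072): 1024 R, 1024 G, 1024 B per image
    labels = batch[b"labels"]  # list of 10000 class indices in the range 0-9
    return data, labels

data, labels = load_batch("./cifar-10-batches-py/data_batch_1")
print(data.shape, len(labels))  # (10000, 3072) 10000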

The classes are completely mutually exclusive: there is no overlap between automobiles and trucks. "Automobile" covers sedans, SUVs, and things of that sort, while "Truck" covers only big trucks; neither includes pickup trucks. Looking through the CIFAR dataset, we realize it does not hold just one kind of bird or cat. The bird and cat classes contain many different types of birds and cats, varying in size, color, magnification, viewing angle, and pose.

With the MNIST dataset, a digit such as a one or a two can be written in many different ways, but the data is simply not as diverse, and, most importantly, MNIST is grayscale. The CIFAR dataset, by contrast, consists of 32x32 color images, each with three color channels. The key question now is: will the LeNet model, which performed so well on MNIST, be good enough to classify the CIFAR dataset?
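For a concrete comparison, the built-in tf.keras loaders can report both datasets' shapes. This is a side note assuming a TensorFlow build that bundles tf.keras; it is not part of the use case below:

import tensorflow as tf

# MNIST images are 28x28 grayscale with no channel axis;
# CIFAR-10 images are 32x32 with three color channels.
(mnist_x, _), _ = tf.keras.datasets.mnist.load_data()
(cifar_x, _), _ = tf.keras.datasets.cifar10.load_data()
print(mnist_x.shape)  # (60000, 28, 28)
print(cifar_x.shape)  # (50000, 32, 32, 3)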

The CIFAR-100 Dataset

CIFAR-100 is just like CIFAR-10, except that it has 100 classes containing 600 images each. There are 500 training images and 100 test images per class. The 100 classes are grouped into 20 superclasses. Each image comes with a "coarse" label (the superclass to which it belongs) and a "fine" label (the class to which it belongs).
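The coarse/fine distinction is easy to see with the tf.keras loader, which exposes it through the label_mode argument (again assuming a TensorFlow build that bundles tf.keras):

import tensorflow as tf

# The same images come with two label granularities.
(_, y_fine), _ = tf.keras.datasets.cifar100.load_data(label_mode="fine")
(_, y_coarse), _ = tf.keras.datasets.cifar100.load_data(label_mode="coarse")
print(y_fine.min(), y_fine.max())      # 0 99 -> one of the 100 classes
print(y_coarse.min(), y_coarse.max())  # 0 19 -> one of the 20 superclasses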

The classes in the CIFAR-100 dataset are as follows:

S. No | Superclass | Classes
1. | Flowers | Orchids, poppies, roses, sunflowers, tulips
2. | Fish | Aquarium fish, flatfish, ray, shark, trout
3. | Aquatic mammals | Beaver, dolphin, otter, seal, whale
4. | Food containers | Bottles, bowls, cans, cups, plates
5. | Household electrical devices | Clock, lamp, telephone, television, computer keyboard
6. | Fruit and vegetables | Apples, mushrooms, oranges, pears, sweet peppers
7. | Household furniture | Table, chair, couch, wardrobe, bed
8. | Insects | Bee, beetle, butterfly, caterpillar, cockroach
9. | Large natural outdoor scenes | Cloud, forest, mountain, plain, sea
10. | Large human-made outdoor things | Bridge, castle, house, road, skyscraper
11. | Large carnivores | Bear, leopard, lion, tiger, wolf
12. | Medium-sized mammals | Fox, porcupine, possum, raccoon, skunk
13. | Large omnivores and herbivores | Camel, cattle, chimpanzee, elephant, kangaroo
14. | Non-insect invertebrates | Crab, lobster, snail, spider, worm
15. | Reptiles | Crocodile, dinosaur, lizard, snake, turtle
16. | Trees | Maple, oak, palm, pine, willow
17. | People | Girl, man, woman, baby, boy
18. | Small mammals | Hamster, rabbit, mouse, shrew, squirrel
19. | Vehicles 1 | Bicycle, bus, motorcycle, pickup truck, train
20. | Vehicles 2 | Lawn-mower, rocket, streetcar, tractor, tank

Use Case: Implementing CIFAR-10 with a Convolutional Neural Network in TensorFlow

Now we will train a network to classify images from the CIFAR-10 dataset using a convolutional neural network built with TensorFlow.

Consider the following flowchart to understand the working of the use case:

Install the necessary packages (pickle ships with the Python standard library, so it does not need to be installed separately):

pip3 install numpy tensorflow

Train the network:

import numpy as np
import tensorflow as tf
from time import time
import math

from include.data import get_data_set
from include.model import model, lr

train_x, train_y = get_data_set("train")
test_x, test_y = get_data_set("test")
tf.set_random_seed(21)
x, y, output, y_pred_cls, global_step, learning_rate = model()
global_accuracy = 0
epoch_start = 0

# PARAMS
_BATCH_SIZE = 128
_EPOCH = 60
_SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/"

# LOSS AND OPTIMIZER
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=output, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                   beta1=0.9,
                                   beta2=0.999,
                                   epsilon=1e-08).minimize(loss, global_step=global_step)

# PREDICTION AND ACCURACY CALCULATION
correct_prediction = tf.equal(y_pred_cls, tf.argmax(y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# SAVER
merged = tf.summary.merge_all()
saver = tf.train.Saver()
sess = tf.Session()
train_writer = tf.summary.FileWriter(_SAVE_PATH, sess.graph)

try:
    print("\nTrying to restore last checkpoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
    saver.restore(sess, save_path=last_chk_path)
    print("Restored checkpoint from:", last_chk_path)
except ValueError:
    print("\nFailed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())


def train(epoch):
    # Run one epoch of mini-batch training, reporting progress every 10 steps.
    global epoch_start
    epoch_start = time()
    batch_count = int(math.ceil(len(train_x) / _BATCH_SIZE))
    i_global = 0

    for s in range(batch_count):
        batch_xs = train_x[s * _BATCH_SIZE: (s + 1) * _BATCH_SIZE]
        batch_ys = train_y[s * _BATCH_SIZE: (s + 1) * _BATCH_SIZE]

        start_time = time()
        i_global, _, batch_loss, batch_acc = sess.run(
            [global_step, optimizer, loss, accuracy],
            feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)})
        duration = time() - start_time

        if s % 10 == 0:
            percentage = int(round((s / batch_count) * 100))

            # Draw a simple text progress bar, e.g. [=====>------...].
            bar_len = 29
            filled_len = int((bar_len * percentage) / 100)
            bar = '=' * filled_len + '>' + '-' * (bar_len - filled_len)

            msg = "Global step: {:>5} - [{}] {:>3}% - acc: {:.4f} - loss: {:.4f} - {:.1f} sample/sec"
            print(msg.format(i_global, bar, percentage, batch_acc, batch_loss,
                             _BATCH_SIZE / duration))

    test_and_save(i_global, epoch)


def test_and_save(_global_step, epoch):
    # Evaluate on the full test set and save a checkpoint if accuracy improved.
    global global_accuracy
    global epoch_start

    i = 0
    predicted_class = np.zeros(shape=len(test_x), dtype=int)
    while i < len(test_x):
        j = min(i + _BATCH_SIZE, len(test_x))
        batch_xs = test_x[i:j, :]
        batch_ys = test_y[i:j, :]
        predicted_class[i:j] = sess.run(
            y_pred_cls,
            feed_dict={x: batch_xs, y: batch_ys, learning_rate: lr(epoch)})
        i = j

    correct = (np.argmax(test_y, axis=1) == predicted_class)
    acc = correct.mean() * 100
    correct_numbers = correct.sum()

    hours, rem = divmod(time() - epoch_start, 3600)
    minutes, seconds = divmod(rem, 60)
    mes = "\nEpoch {} - accuracy: {:.2f}% ({}/{}) - time: {:0>2}:{:0>2}:{:05.2f}"
    print(mes.format((epoch + 1), acc, correct_numbers, len(test_x),
                     int(hours), int(minutes), seconds))

    if global_accuracy != 0 and global_accuracy < acc:
        summary = tf.Summary(value=[
            tf.Summary.Value(tag="Accuracy/test", simple_value=acc),
        ])
        train_writer.add_summary(summary, _global_step)
        saver.save(sess, save_path=_SAVE_PATH, global_step=_global_step)
        mes = "This epoch receive better accuracy: {:.2f} > {:.2f}. Saving session..."
        print(mes.format(acc, global_accuracy))
        global_accuracy = acc
    elif global_accuracy == 0:
        global_accuracy = acc

    print("##################################################################################################")


def main():
    train_start = time()

    for i in range(_EPOCH):
        print("\nEpoch: {}/{}\n".format((i + 1), _EPOCH))
        train(i)

    hours, rem = divmod(time() - train_start, 3600)
    minutes, seconds = divmod(rem, 60)
    mes = "Best accuracy per session: {:.2f}, time: {:0>2}:{:0>2}:{:05.2f}"
    print(mes.format(global_accuracy, int(hours), int(minutes), seconds))


if __name__ == "__main__":
    main()
    sess.close()
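Note that the script imports get_data_set and lr from a local include package that the article does not show. The sketch below is only a guess at their contracts (the real include/data.py and include/model.py may differ): get_data_set is assumed to return flattened images with one-hot labels, and lr to map an epoch index to a learning rate.

# Hypothetical stand-ins for the unshown include/ helpers; the real ones may differ.
import numpy as np

def lr(epoch):
    # Assumed step-decay schedule: larger steps early, smaller steps later.
    if epoch < 20:
        return 1e-3
    elif epoch < 40:
        return 1e-4
    return 1e-5

def one_hot(labels, num_classes=10):
    # Convert class indices of shape (N,) to one-hot rows of shape (N, num_classes),
    # matching the tf.argmax(y, axis=1) usage in the training script.
    encoded = np.zeros((len(labels), num_classes), dtype=np.float32)
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded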

Output:

Epoch: 60/60

Global step: 23070 - [>-----------------------------]   0% - acc: 0.9531 - loss: 1.5081 - 7045.4 sample/sec
Global step: 23080 - [>-----------------------------]   3% - acc: 0.9453 - loss: 1.5159 - 7147.6 sample/sec
Global step: 23090 - [=>----------------------------]   5% - acc: 0.9844 - loss: 1.4764 - 7154.6 sample/sec
Global step: 23100 - [==>---------------------------]   8% - acc: 0.9297 - loss: 1.5307 - 7104.4 sample/sec
Global step: 23110 - [==>---------------------------]  10% - acc: 0.9141 - loss: 1.5462 - 7091.4 sample/sec
Global step: 23120 - [===>--------------------------]  13% - acc: 0.9297 - loss: 1.5314 - 7162.9 sample/sec
Global step: 23130 - [====>-------------------------]  15% - acc: 0.9297 - loss: 1.5307 - 7174.8 sample/sec
Global step: 23140 - [=====>------------------------]  18% - acc: 0.9375 - loss: 1.5231 - 7140.0 sample/sec
Global step: 23150 - [=====>------------------------]  20% - acc: 0.9297 - loss: 1.5301 - 7152.8 sample/sec
Global step: 23160 - [======>-----------------------]  23% - acc: 0.9531 - loss: 1.5080 - 7112.3 sample/sec
Global step: 23170 - [=======>----------------------]  26% - acc: 0.9609 - loss: 1.5000 - 7154.0 sample/sec
Global step: 23180 - [========>---------------------]  28% - acc: 0.9531 - loss: 1.5074 - 6862.2 sample/sec
Global step: 23190 - [========>---------------------]  31% - acc: 0.9609 - loss: 1.4993 - 7134.5 sample/sec
Global step: 23200 - [=========>--------------------]  33% - acc: 0.9609 - loss: 1.4995 - 7166.0 sample/sec
Global step: 23210 - [==========>-------------------]  36% - acc: 0.9375 - loss: 1.5231 - 7116.7 sample/sec
Global step: 23220 - [===========>------------------]  38% - acc: 0.9453 - loss: 1.5153 - 7134.1 sample/sec
Global step: 23230 - [===========>------------------]  41% - acc: 0.9375 - loss: 1.5233 - 7074.5 sample/sec
Global step: 23240 - [============>-----------------]  43% - acc: 0.9219 - loss: 1.5387 - 7176.9 sample/sec
Global step: 23250 - [=============>----------------]  46% - acc: 0.8828 - loss: 1.5769 - 7144.1 sample/sec
Global step: 23260 - [==============>---------------]  49% - acc: 0.9219 - loss: 1.5383 - 7059.7 sample/sec
Global step: 23270 - [==============>---------------]  51% - acc: 0.8984 - loss: 1.5618 - 6638.6 sample/sec
Global step: 23280 - [===============>--------------]  54% - acc: 0.9453 - loss: 1.5151 - 7035.7 sample/sec
Global step: 23290 - [================>-------------]  56% - acc: 0.9609 - loss: 1.4996 - 7129.0 sample/sec
Global step: 23300 - [=================>------------]  59% - acc: 0.9609 - loss: 1.4997 - 7075.4 sample/sec
Global step: 23310 - [=================>------------]  61% - acc: 0.8750 - loss: 1.5842 - 7117.8 sample/sec
Global step: 23320 - [==================>-----------]  64% - acc: 0.9141 - loss: 1.5463 - 7157.2 sample/sec
Global step: 23330 - [===================>----------]  66% - acc: 0.9062 - loss: 1.5549 - 7169.3 sample/sec
Global step: 23340 - [====================>---------]  69% - acc: 0.9219 - loss: 1.5389 - 7164.4 sample/sec
Global step: 23350 - [====================>---------]  72% - acc: 0.9609 - loss: 1.5002 - 7135.4 sample/sec
Global step: 23360 - [=====================>--------]  74% - acc: 0.9766 - loss: 1.4842 - 7124.2 sample/sec
Global step: 23370 - [======================>-------]  77% - acc: 0.9375 - loss: 1.5231 - 7168.5 sample/sec
Global step: 23380 - [======================>-------]  79% - acc: 0.8906 - loss: 1.5695 - 7175.2 sample/sec
Global step: 23390 - [=======================>------]  82% - acc: 0.9375 - loss: 1.5225 - 7132.1 sample/sec
Global step: 23400 - [========================>-----]  84% - acc: 0.9844 - loss: 1.4768 - 7100.1 sample/sec
Global step: 23410 - [=========================>----]  87% - acc: 0.9766 - loss: 1.4840 - 7172.0 sample/sec
Global step: 23420 - [==========================>---]  90% - acc: 0.9062 - loss: 1.5542 - 7122.1 sample/sec
Global step: 23430 - [==========================>---]  92% - acc: 0.9297 - loss: 1.5313 - 7145.3 sample/sec
Global step: 23440 - [===========================>--]  95% - acc: 0.9297 - loss: 1.5301 - 7133.3 sample/sec
Global step: 23450 - [============================>-]  97% - acc: 0.9375 - loss: 1.5231 - 7135.7 sample/sec
Global step: 23460 - [=============================>] 100% - acc: 0.9250 - loss: 1.5362 - 10297.5 sample/sec

Epoch 60 - accuracy: 78.81% (7881/10000)
This epoch receive better accuracy: 78.81 > 78.78. Saving session...
##################################################################################################
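Because the training loop writes summaries and checkpoints to _SAVE_PATH, progress can also be inspected in TensorBoard by pointing it at that directory:

tensorboard --logdir=./tensorboard/cifar-10-v1.0.0/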

Run the network on the test dataset:

import numpy as np
import tensorflow as tf

from include.data import get_data_set
from include.model import model

test_x, test_y = get_data_set("test")
x, y, output, y_pred_cls, global_step, learning_rate = model()

_BATCH_SIZE = 128
_CLASS_SIZE = 10
_SAVE_PATH = "./tensorboard/cifar-10-v1.0.0/"

saver = tf.train.Saver()
sess = tf.Session()

try:
    print("\nTrying to restore last checkpoint ...")
    last_chk_path = tf.train.latest_checkpoint(checkpoint_dir=_SAVE_PATH)
    saver.restore(sess, save_path=last_chk_path)
    print("Restored checkpoint from:", last_chk_path)
except ValueError:
    print("\nFailed to restore checkpoint. Initializing variables instead.")
    sess.run(tf.global_variables_initializer())


def main():
    # Predict the class of every test image, batch by batch.
    i = 0
    predicted_class = np.zeros(shape=len(test_x), dtype=int)
    while i < len(test_x):
        j = min(i + _BATCH_SIZE, len(test_x))
        batch_xs = test_x[i:j, :]
        batch_ys = test_y[i:j, :]
        predicted_class[i:j] = sess.run(y_pred_cls,
                                        feed_dict={x: batch_xs, y: batch_ys})
        i = j

    correct = (np.argmax(test_y, axis=1) == predicted_class)
    acc = correct.mean() * 100
    correct_numbers = correct.sum()
    print()
    print("Accuracy on Test-Set: {0:.2f}% ({1} / {2})".format(acc, correct_numbers, len(test_x)))


if __name__ == "__main__":
    main()
    sess.close()
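y_pred_cls produces integer class indices. If readable labels are wanted, the indices can be mapped onto the standard CIFAR-10 class names; the helper below is illustrative and not part of the article's script:

# Standard CIFAR-10 class names, in label order 0-9.
CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

def to_names(indices):
    # Map an iterable of predicted class indices to readable label strings.
    return [CLASS_NAMES[i] for i in indices]

# For example, inside main() after the while loop:
#   print(to_names(predicted_class[:5]))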

Output:

Trying to restore last checkpoint ...
Restored checkpoint from: ./tensorboard/cifar-10-v1.0.0/-23460

Accuracy on Test-Set: 78.81% (7881 / 10000)

Training Time

Here we can see how much time 60 epochs take:

Device | Batch size | Time | Accuracy [%]
NVidia | 128 | 8m 4s | 79.12
Intel i7-7700HQ | 128 | 3h 30m | 78.91