1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
def show_imgs(n_rows, n_cols, x_data, y_data, class_names):
    """Plot the first ``n_rows * n_cols`` images in a grid, titled by class name.

    Images are taken from the start of ``x_data`` and laid out row by row.
    The figure is shown with ``plt.show()``.

    Args:
        n_rows: Number of grid rows (must be >= 1).
        n_cols: Number of grid columns (must be >= 1).
        x_data: Sequence of images; each item is passed to ``plt.imshow``.
        y_data: Sequence of labels parallel to ``x_data``; each label is used
            as an index into ``class_names``.
        class_names: Sequence mapping a label index to a display name.

    Raises:
        ValueError: If ``x_data`` and ``y_data`` differ in length, if either
            grid dimension is not positive, or if there are fewer than
            ``n_rows * n_cols`` samples to fill the grid.
    """
    # Explicit raises instead of `assert`: asserts are stripped under `python -O`.
    if len(x_data) != len(y_data):
        raise ValueError(
            f"x_data and y_data must have the same length "
            f"(got {len(x_data)} and {len(y_data)})"
        )
    if n_rows < 1 or n_cols < 1:
        raise ValueError(
            f"n_rows and n_cols must be positive (got {n_rows} and {n_cols})"
        )
    n_cells = n_rows * n_cols
    # Exactly n_cells samples is enough to fill the grid; the previous strict
    # `<` check wrongly rejected that case.
    if n_cells > len(x_data):
        raise ValueError(
            f"need at least {n_cells} samples for a {n_rows}x{n_cols} grid "
            f"(got {len(x_data)})"
        )

    # Figure size scales per cell (1.4 wide x 1.6 tall); presumably the extra
    # height leaves room for each subplot's title.
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for index in range(n_cells):
        # matplotlib subplot indices are 1-based and advance left-to-right,
        # top-to-bottom, which matches the row-major order of `index`.
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
        plt.axis("off")
        plt.title(class_names[y_data[index]])
    plt.show()


# Display names indexed by integer label. NOTE(review): the order appears to
# follow the Fashion-MNIST label set — confirm against the dataset in use.
class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# NOTE(review): `plt`, `x_train` and `y_train` are not defined in this snippet;
# they are expected to come from earlier notebook cells (matplotlib.pyplot
# import and dataset loading).
show_imgs(3, 5, x_train, y_train, class_names)
|