20 lines
590 B
Python
20 lines
590 B
Python
from matplotlib import pyplot as plt
|
|
|
|
def show_images(images, n_row=5, n_col=5, figsize=[12,12]):
|
|
_, axs = plt.subplots(n_row, n_col, figsize=figsize)
|
|
axs = axs.flatten()
|
|
for img, ax in zip(images, axs):
|
|
ax.imshow(img, cmap='gray')
|
|
plt.show()
|
|
|
|
def dict_train_test_split(X_dict, y, ratio=0.9):
|
|
N_train = int(len(y)*ratio)
|
|
|
|
X_dict_train = {k:v[:N_train] for k,v in X_dict.items()}
|
|
y_train = y[:N_train]
|
|
|
|
X_dict_test = {k:v[N_train:] for k,v in X_dict.items()}
|
|
y_test = y[N_train:]
|
|
|
|
return X_dict_train, y_train, X_dict_test, y_test
|