def draw_images_grid_with_labels(data, nrows, figsize=(12, 12), **subplots_adjust_param):
    """Plot the images in *data* on a grid, titled with their labels, and show it.

    Each item of *data* is read as ``item[0][0]`` (the image) and ``item[1]``
    (the label). The image goes through ``.permute(1, 2, 0)`` before
    ``imshow``. So it is presumably a channel-first ``torch.Tensor`` of shape
    (C, H, W), and the permute reorders it to (H, W, C). TODO: confirm this
    against the caller that builds ``data``.

    Args:
        data: Sized, iterable collection of ``(images, label)``-like items.
        nrows: Number of grid rows. The column count is
            ``ceil(len(data) / nrows)``, so every item gets a cell. Unused
            cells in the last row are hidden.
        figsize: Figure size in inches, forwarded to ``plt.subplots``.
        **subplots_adjust_param: Keyword arguments forwarded unchanged to
            ``plt.subplots_adjust`` (e.g. ``left``, ``right``, ``bottom``,
            ``top``, ``wspace``, ``hspace``).

    Raises:
        ValueError: If *nrows* is less than 1 or *data* is empty.
    """
    n_items = len(data)
    # Validate before importing matplotlib, so bad arguments fail fast and cheaply.
    if nrows < 1:
        raise ValueError(f"nrows must be a positive integer, got {nrows!r}")
    if n_items == 0:
        raise ValueError("data must contain at least one item")

    import matplotlib.pyplot as plt

    # Ceiling division. Floor division would silently drop the trailing items
    # and give 0 columns (a crash) whenever len(data) < nrows.
    ncols = (n_items + nrows - 1) // nrows
    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # yields a bare Axes object, which has no ravel()/flatten().
    _, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()  # row-major order, same as flatten()

    for item, ax in zip(data, flat_axes):
        ax.imshow(item[0][0].permute(1, 2, 0))  # channel-first -> channel-last
        ax.set_title(item[1])

    # Blank out grid cells past the last item instead of showing empty axes.
    for ax in flat_axes[n_items:]:
        ax.set_axis_off()

    plt.subplots_adjust(**subplots_adjust_param)
    plt.show()
# Keyword arguments for plt.subplots_adjust. These are example values only.
# left/right/bottom/top position the edges of the subplot area as fractions of
# the figure. wspace/hspace are the gaps between subplots, as fractions of the
# average axes width/height.
# NOTE(review): bottom=0 with top=0.3 confines the grid to the lower 30% of the
# figure; tune these for your own data.
params = dict(
    left=0.125,  # left edge of the subplot area
    right=0.9,  # right edge of the subplot area
    bottom=0,  # bottom edge of the subplot area
    top=0.3,  # top edge of the subplot area
    wspace=0.3,  # horizontal gap, as a fraction of the average axes width
    hspace=0.3,  # vertical gap, as a fraction of the average axes height
)
# `res` is expected to be defined earlier by the surrounding script.
draw_images_grid_with_labels(res, nrows=2, **params)
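# 需要根据实际需求调整 params 中的取值,上面给出的只是一组示例。
# (Adjust the values in `params` to suit your own figure; the values above are only an example.)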
# 原文 (original article): https://www.cnblogs.com/konosubaakua/p/14687936.html