"""Plot training loss against learning rate and epoch count.

Reads ``lr_loss.csv`` (a header row followed by ``lr,epochs,loss`` rows) and
draws three separate figures: loss vs. lr, loss vs. epochs, and a 3-D surface
of loss over (lr, epochs) with the lowest-loss point marked in red.
"""
import csv

import numpy as np

# Renamed from ``csv`` so the stdlib ``csv`` module is not shadowed.
CSV_PATH = 'lr_loss.csv'


def parse_csv(path=CSV_PATH):
    """Read the lr/epochs/loss table from *path*.

    The first row is treated as a header and skipped. Each remaining row must
    hold ``lr`` (float), ``epochs`` (int) and ``loss`` (float) in its first
    three columns. Blank lines are ignored.

    Args:
        path: CSV file to read; defaults to ``CSV_PATH``.

    Returns:
        A ``(lr, loss, epochs)`` tuple of parallel lists (note the order).

    Raises:
        OSError: if the file cannot be opened.
        ValueError / IndexError: if a data row is malformed.
    """
    lr, epochs, loss = [], [], []
    with open(path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            # csv.reader yields [] for empty lines; also skip whitespace-only
            # rows so a trailing newline does not crash the parse.
            if not any(cell.strip() for cell in row):
                continue
            lr.append(float(row[0]))
            epochs.append(int(row[1]))
            loss.append(float(row[2]))
    return (lr, loss, epochs)


def _plot_loss_2d(x, loss, xlabel, title):
    """Draw loss against *x* on a log-scaled y axis in a NEW figure."""
    import matplotlib.pyplot as plt

    # A fresh figure per plot: without it, successive plots share one axes
    # and overwrite each other's labels and title.
    plt.figure()
    plt.plot(x, loss)
    plt.yscale('log')
    plt.xlabel(xlabel)
    plt.ylabel('loss')
    plt.title(title)


def _plot_loss_3d(lr, epochs, loss):
    """Draw a loss surface over (lr, epochs) and mark its minimum in red."""
    import matplotlib.pyplot as plt
    # Registers the '3d' projection on matplotlib < 3.2; harmless afterwards.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_trisurf(np.asarray(lr), np.asarray(epochs), np.asarray(loss),
                    linewidth=0, antialiased=False)
    ax.set_xlabel('lr')
    ax.set_ylabel('epochs')
    ax.set_zlabel('loss')

    # list.index returns the first occurrence, so ties go to the earliest row.
    best = loss.index(min(loss))
    ax.scatter(lr[best], epochs[best], loss[best], color='r')
    ax.set_title('lr-epoch-loss')


def main():
    """Parse ``CSV_PATH`` and show all three loss plots."""
    # Imported here so parse_csv stays importable without a plotting backend.
    import matplotlib.pyplot as plt

    lr, loss, epochs = parse_csv()
    if not loss:
        # min() and plot_trisurf would otherwise fail with opaque errors.
        raise SystemExit(f'{CSV_PATH}: no data rows to plot')

    _plot_loss_2d(lr, loss, 'lr', 'lr-loss')
    _plot_loss_2d(epochs, loss, 'epochs', 'epochs-loss')
    _plot_loss_3d(lr, epochs, loss)
    plt.show()


if __name__ == '__main__':
    main()