plotter.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. ''' Incremental-Classifier Learning
  2. Authors : Khurram Javed, Muhammad Talha Paracha
  3. Maintainer : Khurram Javed
  4. Lab : TUKL-SEECS R&D Lab
  5. Email : 14besekjaved@seecs.edu.pk '''
  6. import matplotlib
  7. import matplotlib.pyplot as plt
  8. plt.switch_backend('agg')
  9. MEDIUM_SIZE = 18
  10. font = {'family': 'sans-serif',
  11. 'weight': 'bold'}
  12. matplotlib.rc('xtick', labelsize=MEDIUM_SIZE)
  13. matplotlib.rc('ytick', labelsize=MEDIUM_SIZE)
  14. plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels
  15. # matplotlib.rc('font', **font)
  16. from matplotlib import rcParams
  17. rcParams.update({'figure.autolayout': True})
  18. class Plotter():
  19. def __init__(self):
  20. import itertools
  21. # plt.figure(figsize=(12, 9))
  22. self.marker = itertools.cycle(('o', '+', "v", "^", "8", '.', '*'))
  23. self.handles = []
  24. self.lines = itertools.cycle(('--', '-.', '-', ':'))
  25. def plot(self, x, y, xLabel="Number of Classes", yLabel="Accuracy %", legend="none", title=None, error=None):
  26. self.x = x
  27. self.y = y
  28. plt.grid(color='0.89', linestyle='--', linewidth=1.0)
  29. if error is None:
  30. l, = plt.plot(x, y, linestyle=next(self.lines), marker=next(self.marker), label=legend, linewidth=3.0)
  31. else:
  32. l = plt.errorbar(x, y, yerr=error, capsize=4.0, capthick=2.0, linestyle=next(self.lines),
  33. marker=next(self.marker), label=legend, linewidth=3.0)
  34. self.handles.append(l)
  35. self.x_label = xLabel
  36. self.y_label = yLabel
  37. if title is not None:
  38. plt.title(title)
  39. def save_fig(self, path, xticks=105, title=None, yStart=0, xRange=0, yRange=10):
  40. if title is not None:
  41. plt.title(title)
  42. plt.legend(handles=self.handles)
  43. plt.ylim((yStart, 100 + 0.2))
  44. plt.xlim((0, xticks + .2))
  45. plt.ylabel(self.y_label)
  46. plt.xlabel(self.x_label)
  47. plt.yticks(list(range(yStart, 101, yRange)))
  48. print(list(range(yStart, 105, yRange)))
  49. plt.xticks(list(range(0, xticks + 1, xRange + int(xticks / 10))))
  50. plt.savefig(path + ".eps", format='eps')
  51. plt.gcf().clear()
  52. def save_fig2(self, path, xticks=105):
  53. plt.legend(handles=self.handles)
  54. plt.xlabel("Memory Budget")
  55. plt.ylabel("Average Incremental Accuracy")
  56. plt.savefig(path + ".jpg")
  57. plt.gcf().clear()
  58. def plotMatrix(self, epoch, path, img):
  59. plt.imshow(img, cmap='plasma', interpolation='nearest')
  60. plt.colorbar()
  61. plt.savefig(path + str(epoch) + ".svg", format='svg')
  62. plt.gcf().clear()
  63. def saveImage(self, img, path, epoch):
  64. from PIL import Image
  65. im = Image.fromarray(img)
  66. im.save(path + str(epoch) + ".jpg")
  67. if __name__ == "__main__":
  68. pl = Plotter()
  69. pl.plot([1, 2, 3, 4], [2, 3, 6, 2])
  70. pl.save_fig("test.jpg")