Module livelossplot.outputs.matplotlib_subplots
Expand source code
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
class BaseSubplot:
    """Base class for custom subplots.

    Subclasses are expected to override :meth:`draw`. Calling an instance
    forwards all arguments to :meth:`draw`.
    """

    def __init__(self):
        pass

    def draw(self, *args, **kwargs):
        """Render the subplot; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # NotImplementedError is the conventional signal for an abstract
        # method; it is an Exception subclass, so existing handlers that
        # catch Exception keep working.
        raise NotImplementedError("Not implemented")

    def __call__(self, *args, **kwargs):
        """Forward positional and keyword arguments to :meth:`draw`."""
        self.draw(*args, **kwargs)
class LossSubplot(BaseSubplot):
    """To rewrite; this one currently won't work.

    The constructor deliberately ends with ``raise NotImplementedError()``,
    so instances cannot be created until the class is rewritten.
    """

    def __init__(
        self, metric, title="", series_fmt={
            'training': '{}',
            'validation': 'val_{}'
        }, skip_first=2, max_epoch=None
    ):
        """Store the plot configuration, then refuse construction.

        Args:
            metric: base metric name substituted into each ``series_fmt`` pattern.
            title: title applied to the plot.
            series_fmt: mapping of series label -> format pattern used to build
                the per-series key looked up in each log entry. The default
                dict is shared between calls; it is only read here.
            skip_first: number of leading log entries to leave out of the plot.
            max_epoch: if not None, used as the upper x-axis limit.

        Raises:
            NotImplementedError: always (the class is disabled).
        """
        # super().__init__(self) passed ``self`` twice and raised TypeError
        # before the intended NotImplementedError could be reached.
        super().__init__()
        self.metric = metric
        self.title = title
        self.series_fmt = series_fmt
        self.skip_first = skip_first
        self.max_epoch = max_epoch
        raise NotImplementedError()

    def _how_many_to_skip(self, log_length, skip_first):
        """Return how many leading log entries to skip.

        Nothing is skipped while there are fewer than ``skip_first`` entries;
        between ``skip_first`` and ``2 * skip_first`` the skip count grows as
        ``log_length - skip_first``; from then on it is ``skip_first``.
        """
        if log_length < skip_first:
            return 0
        if log_length < 2 * skip_first:
            return log_length - skip_first
        return skip_first

    def draw(self, logs):
        """Plot every configured series of ``self.metric`` from ``logs``.

        Args:
            logs: sequence of dict-like log entries. An entry may carry its
                x position under the ``'_i'`` key; otherwise its 1-based
                position in ``logs`` is used.
        """
        skip = self._how_many_to_skip(len(logs), self.skip_first)
        if self.max_epoch is not None:
            plt.xlim(1 + skip, self.max_epoch)
        for serie_label, serie_fmt in self.series_fmt.items():
            serie_metric_name = serie_fmt.format(self.metric)
            # Start counting at skip + 1 so the fallback x value is the
            # entry's position in the full log, matching the xlim above
            # (previously it restarted at 1 after skipping).
            serie_metric_logs = [
                (log.get('_i', position), log[serie_metric_name])
                for position, log in enumerate(logs[skip:], start=skip + 1)
                if serie_metric_name in log
            ]
            if serie_metric_logs:
                xs, ys = zip(*serie_metric_logs)
                plt.plot(xs, ys, label=serie_label)
        plt.title(self.title)
        plt.xlabel('epoch')
        plt.legend(loc='center right')
class Plot1D(BaseSubplot):
    """Subplot showing ground-truth points against a model's predictions."""

    def __init__(self, model, X, Y):
        """Store the model and the data to plot.

        Args:
            model: object passed to :meth:`predict`; the default
                implementation calls ``model.predict(X)``.
            X: x values, plotted on the horizontal axis and fed to the model.
            Y: ground-truth values plotted against ``X``.
        """
        # super().__init__(self) passed ``self`` twice and raised TypeError,
        # which made the class impossible to instantiate.
        super().__init__()
        self.model = model
        self.X = X
        self.Y = Y

    def predict(self, model, X):
        """Return the model's predictions for ``X``.

        Override for models without a ``predict`` method,
        e.g. ``model(torch.from_numpy(X)).detach().numpy()``.
        """
        return model.predict(X)

    def draw(self, *args, **kwargs):
        """Plot ground truth as red dots and the model output as a line.

        Positional and keyword arguments are accepted but ignored.
        """
        plt.plot(self.X, self.Y, 'r.', label="Ground truth")
        plt.plot(self.X, self.predict(self.model, self.X), '-', label="Model")
        plt.title("Prediction")
        plt.legend(loc='lower right')
class Plot2d(BaseSubplot):
    """Subplot showing a model's output over a 2D grid, with the data points on top."""

    def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25, device='cpu'):
        """Store the model and data, and precompute the evaluation grid.

        Args:
            model: callable applied to a float tensor in ``_predict_pytorch``.
            X: array indexed as ``X[:, 0]`` and ``X[:, 1]`` (two feature columns).
            Y: per-point values used to colour the scatter of ``X``.
            valiation_data: ``(X_test, Y_test)`` pair; the test points are
                drawn only when ``X_test`` is not None. The misspelled name is
                kept so existing keyword callers do not break.
            h: step of the evaluation grid along both axes.
            margin: padding added around the min/max of each feature column.
            device: device the input tensor is moved to before calling the model.
        """
        super().__init__()
        self.model = model
        self.X = X
        self.Y = Y
        self.X_test, self.Y_test = valiation_data
        # add size assertions
        self.cm_bg = plt.cm.RdBu
        self.cm_points = ListedColormap(['#FF0000', '#0000FF'])

        x_min = X[:, 0].min() - margin
        x_max = X[:, 0].max() + margin
        y_min = X[:, 1].min() - margin
        y_max = X[:, 1].max() + margin
        self.xx, self.yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
        self.torch_device = device

    def _predict_pytorch(self, model, x_numpy):
        """Run ``model`` on ``x_numpy`` and return softmax (dim=1) output as a numpy array."""
        # Imported lazily so torch is only required when this method is used.
        import torch
        x = torch.from_numpy(x_numpy).to(self.torch_device).float()
        return model(x).softmax(dim=1).detach().cpu().numpy()

    def predict(self, model, X):
        """Return ``model.predict(X)``.

        Not used by :meth:`draw`, which goes through ``_predict_pytorch``.
        e.g. ``model(torch.from_numpy(X)).detach().numpy()`` for other models.
        """
        return model.predict(X)

    def draw(self, *args, **kwargs):
        """Draw the filled contour of the model output and scatter the data.

        Defined so that calling the instance (``BaseSubplot.__call__``) works
        instead of hitting the base class's "Not implemented" error.
        Positional and keyword arguments are accepted but ignored.
        """
        grid_points = np.c_[self.xx.ravel(), self.yy.ravel()]
        # Column 1 of the softmax output is used as the contour height
        # (presumably the probability of the second class - TODO confirm).
        Z = self._predict_pytorch(self.model, grid_points)[:, 1]
        Z = Z.reshape(self.xx.shape)
        plt.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8)
        plt.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points)
        if self.X_test is not None:
            plt.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3)

    def send(self, logger):
        """Render the plot; ``logger`` is accepted but not used.

        Kept for backward compatibility; delegates to :meth:`draw`.
        """
        self.draw()
Classes
class BaseSubplot
-
Expand source code
class BaseSubplot: def __init__(self): pass def draw(self, *args, **kwargs): raise Exception("Not implemented") def __call__(self, *args, **kwargs): self.draw(*args, **kwargs)
Subclasses
Methods
def draw(self, *args, **kwargs)
-
Expand source code
def draw(self, *args, **kwargs): raise Exception("Not implemented")
class LossSubplot (metric, title='', series_fmt={'training': '{}', 'validation': 'val_{}'}, skip_first=2, max_epoch=None)
-
To rewrite; this one currently won't work
Expand source code
class LossSubplot(BaseSubplot): """To rewrire, this one now won't work""" def __init__( self, metric, title="", series_fmt={ 'training': '{}', 'validation': 'val_{}' }, skip_first=2, max_epoch=None ): super().__init__(self) self.metric = metric self.title = title self.series_fmt = series_fmt self.skip_first = skip_first self.max_epoch = max_epoch raise NotImplementedError() def _how_many_to_skip(self, log_length, skip_first): if log_length < skip_first: return 0 elif log_length < 2 * skip_first: return log_length - skip_first else: return skip_first def draw(self, logs): skip = self._how_many_to_skip(len(logs), self.skip_first) if self.max_epoch is not None: plt.xlim(1 + skip, self.max_epoch) for serie_label, serie_fmt in self.series_fmt.items(): serie_metric_name = serie_fmt.format(self.metric) serie_metric_logs = [ (log.get('_i', i + 1), log[serie_metric_name]) for i, log in enumerate(logs[skip:]) if serie_metric_name in log ] if len(serie_metric_logs) > 0: xs, ys = zip(*serie_metric_logs) plt.plot(xs, ys, label=serie_label) plt.title(self.title) plt.xlabel('epoch') plt.legend(loc='center right')
Ancestors
Methods
def draw(self, logs)
-
Expand source code
def draw(self, logs): skip = self._how_many_to_skip(len(logs), self.skip_first) if self.max_epoch is not None: plt.xlim(1 + skip, self.max_epoch) for serie_label, serie_fmt in self.series_fmt.items(): serie_metric_name = serie_fmt.format(self.metric) serie_metric_logs = [ (log.get('_i', i + 1), log[serie_metric_name]) for i, log in enumerate(logs[skip:]) if serie_metric_name in log ] if len(serie_metric_logs) > 0: xs, ys = zip(*serie_metric_logs) plt.plot(xs, ys, label=serie_label) plt.title(self.title) plt.xlabel('epoch') plt.legend(loc='center right')
class Plot1D (model, X, Y)
-
Expand source code
class Plot1D(BaseSubplot): def __init__(self, model, X, Y): super().__init__(self) self.model = model self.X = X self.Y = Y def predict(self, model, X): # e.g. model(torch.fromnumpy(X)).detach().numpy() return model.predict(X) def draw(self, *args, **kwargs): plt.plot(self.X, self.Y, 'r.', label="Ground truth") plt.plot(self.X, self.predict(self.model, self.X), '-', label="Model") plt.title("Prediction") plt.legend(loc='lower right')
Ancestors
Methods
def draw(self, *args, **kwargs)
-
Expand source code
def draw(self, *args, **kwargs): plt.plot(self.X, self.Y, 'r.', label="Ground truth") plt.plot(self.X, self.predict(self.model, self.X), '-', label="Model") plt.title("Prediction") plt.legend(loc='lower right')
def predict(self, model, X)
-
Expand source code
def predict(self, model, X): # e.g. model(torch.fromnumpy(X)).detach().numpy() return model.predict(X)
class Plot2d (model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25, device='cpu')
-
Expand source code
class Plot2d(BaseSubplot): def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25, device='cpu'): super().__init__() self.model = model self.X = X self.Y = Y self.X_test, self.Y_test = valiation_data # add size assertions self.cm_bg = plt.cm.RdBu self.cm_points = ListedColormap(['#FF0000', '#0000FF']) x_min = X[:, 0].min() - margin x_max = X[:, 0].max() + margin y_min = X[:, 1].min() - margin y_max = X[:, 1].max() + margin self.xx, self.yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) self.torch_device = device def _predict_pytorch(self, model, x_numpy): import torch x = torch.from_numpy(x_numpy).to(self.torch_device).float() return model(x).softmax(dim=1).detach().cpu().numpy() def predict(self, model, X): # e.g. model(torch.fromnumpy(X)).detach().numpy() return model.predict(X) def send(self, logger): Z = self._predict_pytorch(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1] Z = Z.reshape(self.xx.shape) plt.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8) plt.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points) if self.X_test is not None: plt.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3)
Ancestors
Methods
def predict(self, model, X)
-
Expand source code
def predict(self, model, X): # e.g. model(torch.fromnumpy(X)).detach().numpy() return model.predict(X)
def send(self, logger)
-
Expand source code
def send(self, logger): Z = self._predict_pytorch(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1] Z = Z.reshape(self.xx.shape) plt.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8) plt.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points) if self.X_test is not None: plt.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3)