Module livelossplot.outputs.matplotlib_plot

Expand source code
import math
from typing import Tuple, List, Dict, Optional, Callable

import warnings

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import clear_output
from livelossplot.main_logger import MainLogger, LogItem
from livelossplot.outputs.base_output import BaseOutput

class MatplotlibPlot(BaseOutput):
    """Output that draws grouped metrics as a grid of matplotlib charts.

    NOTE: Removed figsize and dynamic_x_axis.
    """
    def __init__(
        self,
        cell_size: Tuple[int, int] = (6, 4),
        max_cols: int = 2,
        max_epoch: Optional[int] = None,
        skip_first: int = 2,
        extra_plots: Optional[List[Callable[[plt.Axes, MainLogger], None]]] = None,
        figpath: Optional[str] = None,
        after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None,
        before_plots: Optional[Callable[[plt.Figure, np.ndarray, int], None]] = None,
        after_plots: Optional[Callable[[plt.Figure], None]] = None,
    ):
        """
        Args:
            cell_size: size of one chart
            max_cols: maximal number of charts in one row
            max_epoch: maximum epoch on x axis
            skip_first: number of first steps to skip
            extra_plots: extra charts functions, each called as ``extra_plot(ax, logger)``
            figpath: path to save figure; it is passed through ``str.format(i=<save counter>)``
            after_subplot: function which will be called after every subplot
            before_plots: function which will be called before all subplots
            after_plots: function which will be called after all subplots
        """
        self.cell_size = cell_size
        self.max_cols = max_cols
        self.skip_first = skip_first  # stored only; no method of this class reads it yet
        # Copy so that instances never share (or mutate) one default list.
        self.extra_plots = list(extra_plots) if extra_plots else []
        self.max_epoch = max_epoch
        self.figpath = figpath
        self.file_idx = 0  # now only for saving files
        self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
        self._before_plots = before_plots if before_plots else self._default_before_plots
        self._after_plots = after_plots if after_plots else self._default_after_plots

    def send(self, logger: MainLogger):
        """Draw figures with metrics and show

        Args:
            logger: logger providing ``grouped_log_history()`` and ``step_names``
        """
        log_groups = logger.grouped_log_history()
        num_groups = len(log_groups)

        max_rows = math.ceil((num_groups + len(self.extra_plots)) / self.max_cols)

        # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid, where
        # plt.subplots would otherwise return a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(max_rows, self.max_cols, squeeze=False)
        self._before_plots(fig, axes, num_groups)

        for group_idx, (group_name, group_logs) in enumerate(log_groups.items()):
            row, col = divmod(group_idx, self.max_cols)
            # Groups without any logged value keep their (empty) cell.
            if any(len(logs) > 0 for logs in group_logs.values()):
                self._draw_metric_subplot(
                    axes[row, col], group_logs, group_name=group_name, x_label=logger.step_names[group_name]
                )

        # Extra plots occupy the cells right after the metric groups.
        for cell_idx, extra_plot in enumerate(self.extra_plots, start=num_groups):
            row, col = divmod(cell_idx, self.max_cols)
            extra_plot(axes[row, col], logger)

        self._after_plots(fig)

    def _default_after_subplot(self, ax: plt.Axes, group_name: str, x_label: str):
        """Add title xlabel and legend to single chart

        Args:
            ax: matplotlib Axes
            group_name: name of metrics group (eg. Accuracy, Recall)
            x_label: label of x axis (eg. epoch, iteration, batch)
        """
        ax.set_title(group_name)
        ax.set_xlabel(x_label)
        ax.legend(loc='center right')

    def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_groups: int) -> None:
        """Set matplotlib window properties

        Args:
            fig: matplotlib Figure
            axes: 2-D array of Axes as created in ``send``
            num_of_log_groups: number of log groups
        """
        # Replace the previously displayed figure instead of stacking a new one.
        clear_output(wait=True)
        figsize_x = self.max_cols * self.cell_size[0]
        # NOTE(review): the height ignores extra plots and can exceed the real row
        # count; kept as-is so existing figure sizes do not change.
        figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
        fig.set_size_inches(figsize_x, figsize_y)
        # Hide only the trailing cells that hold neither a metric group nor an
        # extra plot (a modulo test would hide a completely filled last row).
        num_of_plots = num_of_log_groups + len(self.extra_plots)
        for ax in axes.flat[num_of_plots:]:
            ax.set_visible(False)

    def _default_after_plots(self, fig: plt.Figure):
        """Set properties after charts creation

        Args:
            fig: matplotlib Figure
        """
        fig.tight_layout()
        if self.figpath is not None:
            fig.savefig(self.figpath.format(i=self.file_idx))
            self.file_idx += 1
        plt.show()

    def _draw_metric_subplot(self, ax: plt.Axes, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str):
        """Plot every non-empty metric series of one group on a single Axes

        Args:
            ax: matplotlib Axes
            group_logs: list of log items per each group name
            group_name: group name
            x_label: label for the x axis
        """
        # There used to be skip first part, but I skip it first
        if self.max_epoch is not None:
            ax.set_xlim(0, self.max_epoch)

        for name, logs in group_logs.items():
            if len(logs) > 0:
                xs = [log.step for log in logs]
                ys = [log.value for log in logs]
                ax.plot(xs, ys, label=name)

        self._after_subplot(ax, group_name, x_label)

    def _not_inline_warning(self):
        """Warn when the active matplotlib backend is not an inline one"""
        backend = matplotlib.get_backend()
        if "backend_inline" not in backend:
            warnings.warn(
                "livelossplot requires inline plots.\nYour current backend is: {}"
                "\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend)
            )


class MatplotlibPlot (cell_size: Tuple[int, int] = (6, 4), max_cols: int = 2, max_epoch: int = None, skip_first: int = 2, extra_plots: List[Callable[[MainLogger], None]] = [], figpath: Optional[str] = None, after_subplot: Optional[Callable[[matplotlib.axes._axes.Axes, str, str], None]] = None, before_plots: Optional[Callable[[matplotlib.figure.Figure, numpy.ndarray, int], None]] = None, after_plots: Optional[Callable[[matplotlib.figure.Figure], None]] = None)

NOTE: Removed figsize and dynamic_x_axis.


cell_size: size of one chart
max_cols: maximal number of charts in one row
max_epoch: maximum epoch on x axis
skip_first: number of first steps to skip
extra_plots: extra charts functions
figpath: path to save figure
after_subplot: function which will be called after every subplot
before_plots: function which will be called before all subplots
after_plots: function which will be called after all subplots
Expand source code
class MatplotlibPlot(BaseOutput):
    """Output that draws grouped metrics as a grid of matplotlib charts.

    NOTE: Removed figsize and dynamic_x_axis.
    """
    def __init__(
        self,
        cell_size: Tuple[int, int] = (6, 4),
        max_cols: int = 2,
        max_epoch: Optional[int] = None,
        skip_first: int = 2,
        extra_plots: Optional[List[Callable[[plt.Axes, MainLogger], None]]] = None,
        figpath: Optional[str] = None,
        after_subplot: Optional[Callable[[plt.Axes, str, str], None]] = None,
        before_plots: Optional[Callable[[plt.Figure, np.ndarray, int], None]] = None,
        after_plots: Optional[Callable[[plt.Figure], None]] = None,
    ):
        """
        Args:
            cell_size: size of one chart
            max_cols: maximal number of charts in one row
            max_epoch: maximum epoch on x axis
            skip_first: number of first steps to skip
            extra_plots: extra charts functions, each called as ``extra_plot(ax, logger)``
            figpath: path to save figure; it is passed through ``str.format(i=<save counter>)``
            after_subplot: function which will be called after every subplot
            before_plots: function which will be called before all subplots
            after_plots: function which will be called after all subplots
        """
        self.cell_size = cell_size
        self.max_cols = max_cols
        self.skip_first = skip_first  # stored only; no method of this class reads it yet
        # Copy so that instances never share (or mutate) one default list.
        self.extra_plots = list(extra_plots) if extra_plots else []
        self.max_epoch = max_epoch
        self.figpath = figpath
        self.file_idx = 0  # now only for saving files
        self._after_subplot = after_subplot if after_subplot else self._default_after_subplot
        self._before_plots = before_plots if before_plots else self._default_before_plots
        self._after_plots = after_plots if after_plots else self._default_after_plots

    def send(self, logger: MainLogger):
        """Draw figures with metrics and show

        Args:
            logger: logger providing ``grouped_log_history()`` and ``step_names``
        """
        log_groups = logger.grouped_log_history()
        num_groups = len(log_groups)

        max_rows = math.ceil((num_groups + len(self.extra_plots)) / self.max_cols)

        # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid, where
        # plt.subplots would otherwise return a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(max_rows, self.max_cols, squeeze=False)
        self._before_plots(fig, axes, num_groups)

        for group_idx, (group_name, group_logs) in enumerate(log_groups.items()):
            row, col = divmod(group_idx, self.max_cols)
            # Groups without any logged value keep their (empty) cell.
            if any(len(logs) > 0 for logs in group_logs.values()):
                self._draw_metric_subplot(
                    axes[row, col], group_logs, group_name=group_name, x_label=logger.step_names[group_name]
                )

        # Extra plots occupy the cells right after the metric groups.
        for cell_idx, extra_plot in enumerate(self.extra_plots, start=num_groups):
            row, col = divmod(cell_idx, self.max_cols)
            extra_plot(axes[row, col], logger)

        self._after_plots(fig)

    def _default_after_subplot(self, ax: plt.Axes, group_name: str, x_label: str):
        """Add title xlabel and legend to single chart

        Args:
            ax: matplotlib Axes
            group_name: name of metrics group (eg. Accuracy, Recall)
            x_label: label of x axis (eg. epoch, iteration, batch)
        """
        ax.set_title(group_name)
        ax.set_xlabel(x_label)
        ax.legend(loc='center right')

    def _default_before_plots(self, fig: plt.Figure, axes: np.ndarray, num_of_log_groups: int) -> None:
        """Set matplotlib window properties

        Args:
            fig: matplotlib Figure
            axes: 2-D array of Axes as created in ``send``
            num_of_log_groups: number of log groups
        """
        # Replace the previously displayed figure instead of stacking a new one.
        clear_output(wait=True)
        figsize_x = self.max_cols * self.cell_size[0]
        # NOTE(review): the height ignores extra plots and can exceed the real row
        # count; kept as-is so existing figure sizes do not change.
        figsize_y = ((num_of_log_groups + 1) // self.max_cols + 1) * self.cell_size[1]
        fig.set_size_inches(figsize_x, figsize_y)
        # Hide only the trailing cells that hold neither a metric group nor an
        # extra plot (a modulo test would hide a completely filled last row).
        num_of_plots = num_of_log_groups + len(self.extra_plots)
        for ax in axes.flat[num_of_plots:]:
            ax.set_visible(False)

    def _default_after_plots(self, fig: plt.Figure):
        """Set properties after charts creation

        Args:
            fig: matplotlib Figure
        """
        fig.tight_layout()
        if self.figpath is not None:
            fig.savefig(self.figpath.format(i=self.file_idx))
            self.file_idx += 1
        plt.show()

    def _draw_metric_subplot(self, ax: plt.Axes, group_logs: Dict[str, List[LogItem]], group_name: str, x_label: str):
        """Plot every non-empty metric series of one group on a single Axes

        Args:
            ax: matplotlib Axes
            group_logs: list of log items per each group name
            group_name: group name
            x_label: label for the x axis
        """
        # There used to be skip first part, but I skip it first
        if self.max_epoch is not None:
            ax.set_xlim(0, self.max_epoch)

        for name, logs in group_logs.items():
            if len(logs) > 0:
                xs = [log.step for log in logs]
                ys = [log.value for log in logs]
                ax.plot(xs, ys, label=name)

        self._after_subplot(ax, group_name, x_label)

    def _not_inline_warning(self):
        """Warn when the active matplotlib backend is not an inline one"""
        backend = matplotlib.get_backend()
        if "backend_inline" not in backend:
            warnings.warn(
                "livelossplot requires inline plots.\nYour current backend is: {}"
                "\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend)
            )



def send(self, logger: MainLogger)

Draw figures with metrics and show

Expand source code
def send(self, logger: MainLogger):
    """Draw figures with metrics and show

    Args:
        logger: logger providing ``grouped_log_history()`` and ``step_names``
    """
    log_groups = logger.grouped_log_history()
    num_groups = len(log_groups)

    max_rows = math.ceil((num_groups + len(self.extra_plots)) / self.max_cols)

    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(max_rows, self.max_cols, squeeze=False)
    self._before_plots(fig, axes, num_groups)

    for group_idx, (group_name, group_logs) in enumerate(log_groups.items()):
        row, col = divmod(group_idx, self.max_cols)
        # Groups without any logged value keep their (empty) cell.
        if any(len(logs) > 0 for logs in group_logs.values()):
            self._draw_metric_subplot(
                axes[row, col], group_logs, group_name=group_name, x_label=logger.step_names[group_name]
            )

    # Extra plots occupy the cells right after the metric groups.
    for cell_idx, extra_plot in enumerate(self.extra_plots, start=num_groups):
        row, col = divmod(cell_idx, self.max_cols)
        extra_plot(axes[row, col], logger)

    # Run the "after all subplots" callback that the constructor stores;
    # without this call it would never be invoked.
    self._after_plots(fig)


Inherited members