# Source code for trajOptLib.plot.compare

#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2018 Gao Tang <gt70@duke.edu>
#
# Distributed under terms of the MIT license.

"""
compare.py

Plotting helpers used extensively to compare several samples side by side.
"""
import numpy as np
import matplotlib.pyplot as plt
from .common import plotkwargs, getColorCycle, get3dAxis, getIndAlongAxis, scatterkwargs


def _pick_axis(axes, row, col):
    """Return the subplot for grid cell ``(row, col)`` out of ``axes``.

    ``axes`` may be a 2d array of axes, a 1d array/list of axes (one row), or a
    single axis object; the lookups are tried in that order.
    """
    try:
        return axes[row, col]
    except (IndexError, TypeError):
        # 1d numpy arrays reject two indices with IndexError; lists and
        # non-subscriptable objects (a single axis) raise TypeError.
        pass
    try:
        return axes[col]
    except (IndexError, TypeError):
        # Not indexable (or too short): treat ``axes`` as the axis itself.
        return axes


def compare(arr, x=None, ax=None, transpose=False, show=False, **kwargs):
    """Plot several samples ("categories") feature by feature for comparison.

    Every feature gets its own subplot on a near-square grid, and each subplot
    shows one curve per category.

    Parameters
    ----------
    arr : np.ndarray or list of np.ndarray
        Either a 3d array indexed by category first, or a list of 2d arrays
        (one per category). With ``transpose=False`` each category is indexed
        ``(point, feature)``; with ``transpose=True`` it is ``(feature, point)``.
    x : array-like or list, optional
        x-coordinates. A ``list`` is read as one x array per category
        (``x[j]``); any other value is shared by every category. When
        omitted, the point index is used.
    ax : axis, array of axes, or list of axes, optional
        Where to draw. When omitted, a new ``nRow`` x ``nCol`` subplot grid is
        created and ``tight_layout`` is applied to its figure.
    transpose : bool
        Select the ``(feature, point)`` layout described above.
    show : bool
        Call ``plt.show()`` before returning. When False, the plot is not shown
        immediately.
    **kwargs
        Plot options. A key listed in ``plotkwargs`` is forwarded to every
        ``plot`` call. A ``dict`` value maps a category index to its option and
        is forwarded (for the categories it contains) regardless of key. If
        neither ``color`` nor ``c`` ends up set, categories cycle through
        ``getColorCycle()``.

    Returns
    -------
    The created axes (or the ``ax`` argument, when given).

    Raises
    ------
    TypeError
        If ``arr`` is neither a numpy array nor a list.
    ValueError
        If ``arr`` has no features to plot.

    TODO: headtail = True could distinguish the first and last samples.
    """
    colors = getColorCycle()
    if isinstance(arr, np.ndarray):
        nCat = arr.shape[0]
        nFeature = arr.shape[1] if transpose else arr.shape[2]
    elif isinstance(arr, list):
        nCat = len(arr)
        for arr_ in arr:
            assert isinstance(arr_, np.ndarray)
        nFeature = arr[0].shape[0] if transpose else arr[0].shape[1]
    else:
        raise TypeError("arr must be a numpy array or a list of numpy arrays")
    if nFeature == 0:
        raise ValueError("arr has no features to plot")

    # A list for x means one x array per category.
    useList = isinstance(x, list)

    # Near-square grid: floor(sqrt(n)) rows, and enough columns to hold n.
    nRow = int(np.floor(np.sqrt(nFeature)))
    nCol, rem = divmod(nFeature, nRow)
    if rem:
        nCol += 1

    if ax is None:
        fig, axes = plt.subplots(nRow, nCol)
        tight = True
    else:
        axes = ax  # caller-supplied; assumed to hold enough axes
        tight = False

    for i in range(nFeature):
        sub = _pick_axis(axes, i // nCol, i % nCol)
        if not kwargs and x is None and isinstance(arr, np.ndarray):
            # Fast path: no options and no x, so plot all categories at once
            # with matplotlib's default color cycle.
            sub.plot((arr[:, i, :] if transpose else arr[:, :, i]).T)
            continue
        for j in range(nCat):
            arr_ = arr[j][i, :] if transpose else arr[j][:, i]
            dct = {}
            for key, item in kwargs.items():
                if isinstance(item, dict):
                    if j in item:
                        dct[key] = item[j]
                elif key in plotkwargs:
                    dct[key] = item
            if 'color' not in dct and 'c' not in dct:
                dct['color'] = colors[j % len(colors)]  # wrap around the cycle
            if x is None:
                sub.plot(arr_, **dct)
            elif useList:
                sub.plot(x[j], arr_, **dct)
            else:
                sub.plot(x, arr_, **dct)
    if tight:
        fig.tight_layout()
    if show:
        plt.show()
    return axes
def compareXYZ(arr, ax=None, transpose=False, d3=False, scatter=False, show=False, **kwargs):
    """Plot selected features of each category against one another in 2d/3d.

    This is a hybrid of ``compare`` and a plain plot. ``arr`` holds a
    (category, N, dim) dataset, and a few feature columns/rows are drawn
    against each other, one curve per category.

    Parameters
    ----------
    arr : np.ndarray
        3d array. With ``transpose=False`` the features lie along axis 2; with
        ``transpose=True`` they lie along axis 1.
    ax : axis, optional
        Axis to draw on. When omitted, a new figure is created (3d via
        ``get3dAxis`` when ``d3``).
    transpose : bool
        Select the feature axis as described above.
    d3 : bool
        Draw in 3d, using feature ``zind`` as the third coordinate.
    scatter : bool
        Use ``ax.scatter`` instead of ``ax.plot``.
    show : bool
        Call ``plt.show()`` before returning.
    **kwargs
        ``xind``, ``yind`` and ``zind`` (defaults 0, 1, 2) choose which features
        are used as coordinates. Any other key listed in ``plotkwargs`` is
        forwarded to every draw call. A ``dict`` value maps a category index to
        its option and is forwarded (for the categories it contains)
        regardless of key. If neither ``color`` nor ``c`` is set, categories
        cycle through ``getColorCycle()``.

    Returns
    -------
    The axis that was drawn on.

    NOTE(review): options are filtered with ``plotkwargs`` even when
    ``scatter`` is True. ``scatterkwargs`` is imported by this module but is
    not used here; confirm which list is intended.
    """
    assert isinstance(arr, np.ndarray)
    assert arr.ndim == 3
    colors = getColorCycle()
    nCat = arr.shape[0]
    # Axis of ``arr`` that indexes the features.
    alongaxis = 1 if transpose else 2

    if ax is None:
        if d3:
            _, ax = get3dAxis()
        else:
            _, ax = plt.subplots()

    # Which features serve as the x / y (/ z) coordinates.
    xind = kwargs.get('xind', 0)
    yind = kwargs.get('yind', 1)  # previously read the 'xind' key by mistake
    # Each entry is presumably a (category, N) array, indexed per category below.
    coords = [getIndAlongAxis(arr, alongaxis, xind),
              getIndAlongAxis(arr, alongaxis, yind)]
    if d3:
        coords.append(getIndAlongAxis(arr, alongaxis, kwargs.get('zind', 2)))

    draw = ax.scatter if scatter else ax.plot
    for j in range(nCat):
        # Per-category options; dict-valued options can be set in bulk.
        dct = {}
        for key, item in kwargs.items():
            if isinstance(item, dict):
                if j in item:
                    dct[key] = item[j]
            elif key in plotkwargs:
                dct[key] = item
        if 'color' not in dct and 'c' not in dct:
            dct['color'] = colors[j % len(colors)]  # wrap around the cycle
        draw(*(coord[j] for coord in coords), **dct)
    if show:
        plt.show()
    return ax