Note
Go to the end to download the full example code
utils
import os
import random
from typing import Union, List, Tuple, Any
from collections.abc import KeysView, ValuesView
import shap
from shap.plots import scatter as sh_scatter
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import KFold
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from ai4water.functional import Model as f_model
from ai4water import Model
from ai4water.models import MLP
from ai4water.postprocessing import prediction_distribution_plot
from ai4water.preprocessing import DataSet
from ai4water.utils.utils import dateandtime_now
from SeqMetrics import RegressionMetrics
from easy_mpl import violin_plot, scatter, boxplot
from easy_mpl.utils import is_rgb
from easy_mpl.utils import BAR_CMAPS
from easy_mpl.utils import process_axes
from easy_mpl.utils import create_subplots
from easy_mpl.utils import to_1d_array, make_cols_from_cmap
# Adsorbent name (as it appears in the data) -> adsorbent class label
# ("GB", "AC" or "Biochar"). Insertion order is GB, AC, then Biochar names.
ADSORBENT_TYPES = {
    **dict.fromkeys(("GIC", "Exfoliated GIC"), "GB"),
    **dict.fromkeys((
        "PAC", "APAC", "CS",
        "AC600", "AC700", "AC800", "AC900",
        "CMCAC",
        "CS-AC-KOH", "CS-AC-NaOH", "CS-AC-H3PO4", "CS-AC-H4P2O7",
        "TSAC",
        "MC350", "MC400", "MC450", "MC500", "MC550", "MC600",
        "MC0.75", "MC0.659", "MC0.569", "MC0.478",
        "MC20/1", "MC25/1", "MC30/1", "MC35/1",
        "MCNaOH10", "MCNaOH30", "MCNaOH40", "MCNaOH50",
        "GSAC", "CAS", "SAC", "HAC", "CAC", "CBAC", "VAC", "TRAC",
        "BGBHAC", "GSAC-Ce-1", "TWAC", "WSAC",
    ), "AC"),
    **dict.fromkeys((
        "PSB", "PSB-LDHMgAl", "RH Biochar",
        "M-Biochar", "MN-Biochar", "MZ-Biochar",
    ), "Biochar"),
}
# Dye abbreviation (as it appears in the data) -> ionic class of the dye.
DYE_TYPES = {
    **dict.fromkeys(
        ("CR", "FG FCF", "MO", "NR", "AR", "RB5", "RD", "AB25"),
        "Anionic",
    ),
    **dict.fromkeys(
        ("BV14", "MB", "SYF", "MV", "GR", "Rhd B", "YD", "AM"),
        "Cationic",
    ),
}
def _ohe_column(df:pd.DataFrame, col_name:str)->tuple:
    """One hot encode the column ``col_name`` of ``df`` in place.

    The column is removed and replaced by one new column per category, named
    ``f"{col_name}_{i}"`` and appended at the end of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        dataframe containing ``col_name``; it is modified in place.
    col_name : str
        name of the column to encode.

    Returns
    -------
    tuple
        ``(df, cols_added, encoder)``: the same (mutated) dataframe, the names
        of the added columns and the fitted ``OneHotEncoder`` (usable for
        ``inverse_transform``).
    """
    assert isinstance(col_name, str)
    # A dense numpy array is wanted instead of a scipy sparse matrix. The
    # keyword was renamed from ``sparse`` to ``sparse_output`` in
    # scikit-learn 1.2 and ``sparse`` was removed in 1.4, so try the new name
    # first and fall back to the old one for older versions.
    try:
        encoder = OneHotEncoder(sparse_output=False)
    except TypeError:
        encoder = OneHotEncoder(sparse=False)
    ohe_cat = encoder.fit_transform(df[col_name].values.reshape(-1, 1))
    cols_added = [f"{col_name}_{i}" for i in range(ohe_cat.shape[-1])]
    df[cols_added] = ohe_cat
    df.pop(col_name)
    return df, cols_added, encoder
def _load_data(input_features:list=None)->pd.DataFrame:
    """Read, merge and clean the two raw excel files stored next to this module.

    Unneeded columns are dropped from both files, the rows are concatenated,
    rows with any missing value are removed and the columns are renamed. The
    categorical columns ``Adsorbent`` and ``Dye`` and then the target
    ``Adsorption`` are moved to the end. ``'Fast Green FCF'`` is shortened to
    ``'FG FCF'`` in ``Dye`` because the long name stretches SHAP scatter plots.

    Parameters
    ----------
    input_features : list, optional
        columns to keep as inputs; defaults to every column except the target.

    Returns
    -------
    pd.DataFrame
        the ``input_features`` columns followed by ``Adsorption``.
    """
    here = os.path.dirname(__file__)
    # the data is on the first sheet of both files
    adsorption_df = pd.read_excel(os.path.join(here, 'Adsorption and regeneration data_1007c.xlsx'))
    dyes_df = pd.read_excel(os.path.join(here, 'Dyes data.xlsx'))

    unused_adsorption_cols = (['final concentation', 'Volume (mL)']
                              + [f'Unnamed: {i}' for i in range(16, 24)]
                              + ['Particle size'])
    unused_dye_cols = ['C', 'H', 'O', 'N', 'Ash', 'H/C', 'O/C',
                       'N/C', 'Average pore size',
                       'rpm', 'g/L', 'Ion Concentration (M)',
                       'Humic acid', 'wastewater type',
                       'Adsorption type', 'Cf', 'Ref']
    adsorption_df = adsorption_df.drop(columns=unused_adsorption_cols)
    dyes_df = dyes_df.drop(columns=unused_dye_cols)

    # stack both tables, drop incomplete rows and renumber the index 0..n-1
    merged = pd.concat([adsorption_df, dyes_df]).dropna().reset_index(drop=True)
    merged.columns = ['Adsorption Time (min)', 'Adsorbent', 'Pyrolysis Temperature',
                      'Pyrolysis Time (min)', 'Dye', 'Initial Concentration', 'Solution pH',
                      'Adsorbent Loading', 'Volume (L)', 'Adsorption Temperature',
                      'Surface Area', 'Pore Volume', 'Adsorption']

    # move the categorical columns and then the target to the end
    for column in ('Adsorbent', 'Dye', 'Adsorption'):
        merged[column] = merged.pop(column)
    merged['Dye'] = merged['Dye'].str.replace('Fast Green FCF', 'FG FCF')

    target = ['Adsorption']
    if input_features is None:
        input_features = merged.columns.tolist()[:-1]
    else:
        assert isinstance(input_features, list)
        assert all(feature in merged.columns for feature in input_features)

    return merged[input_features + target]
def make_data(
        input_features:list = None,
        encoding:str = None
)->Tuple[pd.DataFrame, Any, Any]:
    """
    prepares data for adsorption capacity prediction.

    Parameters
    ----------
    input_features : list
        names of variables to use as input. By default the following features
        are used as input features

            - `Adsorption Time (min)`
            - `Adsorbent`
            - `Pyrolysis Temperature`
            - `Pyrolysis Time (min)`
            - `Dye`
            - `Initial Concentration`
            - `Solution pH`
            - `Adsorbent Loading`
            - `Volume (L)`
            - `Adsorption Temperature`
            - `Surface Area`
            - `Pore Volume`

    encoding : str (default=None)
        how to encode the categorical variables ``Adsorbent`` and ``Dye``:
        ``"ohe"`` for one hot encoding, ``"le"`` for label encoding, or
        None (or False) to keep them as strings.

    Returns
    -------
    data : pd.DataFrame
        a pandas dataframe whose last column is the target ``Adsorption``.
        With the default input features, the first 10 columns (0-9) are
        numerical features and the categorical features follow. If encoding
        is 'ohe' the returned dataframe has 75 columns: 0-9 numerical
        features, 10-57 adsorbents, 58-73 dyes and 74 the target. Without
        encoding the returned dataframe has 13 columns.
    adsorbent_encoder :
        the fitted encoder of ``Adsorbent`` (``OneHotEncoder`` for 'ohe',
        ``LabelEncoder`` for 'le') or None when no encoding is applied.
    dye_encoder :
        the same for ``Dye``.

    Raises
    ------
    ValueError
        if ``encoding`` is not one of None, False, ``"ohe"`` or ``"le"``.

    Examples
    --------
    >>> data, ae, de = make_data(encoding="ohe")
    >>> data.shape
    (1514, 75)
    >>> len(ae.categories_[0])
    48

    to get the original adsorbent values we can do as below

    >>> ae.inverse_transform(data.iloc[:, 10:58].values)
    >>> len(de.categories_[0])
    16

    We can also convert the one hot encoded dye columns into original/string form as

    >>> de.inverse_transform(data.iloc[:, 58:-1].values)

    If we don't want to encode categorical features, we can leave encoding as None

    >>> data, _, _ = make_data()
    >>> data.shape
    (1514, 13)
    """
    # previously any other value (e.g. a typo such as "OHE") silently
    # returned unencoded data
    if encoding not in (None, False, "ohe", "le"):
        raise ValueError(
            f"encoding must be None, 'ohe' or 'le' but got {encoding!r}")

    data = _load_data(input_features)

    adsorbent_encoder, dye_encoder = None, None
    if encoding == "ohe":
        # applying One Hot Encoding
        data, _, adsorbent_encoder = _ohe_column(data, 'Adsorbent')
        data, _, dye_encoder = _ohe_column(data, 'Dye')
    elif encoding == "le":
        # applying Label Encoding
        data, adsorbent_encoder = le_column(data, 'Adsorbent')
        data, dye_encoder = le_column(data, 'Dye')

    # one hot encoding appends columns after the target, so move it back last
    target = data.pop('Adsorption')
    data['Adsorption'] = target

    return data, adsorbent_encoder, dye_encoder
def le_column(df:pd.DataFrame, col_name)->tuple:
    """Label encode the column ``col_name`` of ``df`` in place.

    Returns
    -------
    tuple
        ``(df, encoder)``: the same (mutated) dataframe and the fitted
        ``LabelEncoder``.
    """
    label_encoder = LabelEncoder()
    codes = label_encoder.fit_transform(df[col_name])
    df[col_name] = codes
    return df, label_encoder
def get_dataset(encoding="ohe"):
    """Wrap the prepared data in an ai4water ``DataSet``.

    The last column of the data is used as output and all others as inputs;
    the split is random with seed 1575 and no validation set.

    Returns
    -------
    tuple
        ``(dataset, adsorbent_encoder, dye_encoder)``
    """
    data, ads_enc, dye_enc = make_data(encoding=encoding)
    *inputs, output = data.columns.tolist()

    dataset = DataSet(
        data=data,
        seed=1575,
        val_fraction=0.0,
        split_random=True,
        input_features=inputs,
        output_features=[output],
    )
    return dataset, ads_enc, dye_enc
def make_path():
    """Create and return a new folder ``<cwd>/results/mlp_<date and time>``.

    ``os.makedirs`` is called without ``exist_ok``, so an already existing
    folder with the same name raises ``FileExistsError``.
    """
    folder_name = f'mlp_{dateandtime_now()}'
    folder = os.path.join(os.getcwd(), 'results', folder_name)
    os.makedirs(folder)
    return folder
def get_fitted_model(return_path=False,
                     model_type=None,
                     from_config=True):
    """Return an MLP model for adsorption capacity prediction.

    Parameters
    ----------
    return_path : bool
        if True, the model folder is returned as well.
    model_type : str, optional
        ``'functional'`` uses ``ai4water.functional.Model``; any other value
        uses ``ai4water.Model``.
    from_config : bool
        if True, the model is rebuilt from ``config.json`` and the saved
        weights in ``results/mlp_20221217_213202`` next to this file, and the
        saved ``losses.csv`` is wrapped in a history object. Otherwise a new
        results folder is created and a fresh model is trained on the
        training data.

    Returns
    -------
    tuple
        ``(model, h)`` or ``(model, path, h)`` when ``return_path`` is True.
        ``h`` is the result of ``model.fit`` or, when loaded from config, an
        object whose ``history`` attribute is ``{'loss': [...], 'val_loss': [...]}``.
    """
    dataset, _, _ = get_dataset()
    X_train, y_train = dataset.training_data()

    model_cls = f_model if model_type == 'functional' else Model

    if from_config:
        path = os.path.join(os.path.dirname(__file__), 'results', 'mlp_20221217_213202')
        cpath = os.path.join(path, 'config.json')
        model = model_cls.from_config_file(config_path=cpath)
        wpath = os.path.join(path, 'weights_585_1982.99475.hdf5')
        model.update_weights(wpath)
        fpath = os.path.join(path, 'losses.csv')
        losses = pd.read_csv(fpath)[['loss', 'val_loss']]

        class History(object):
            """Minimal stand-in for the history object returned by ``fit``."""
            def __init__(self):
                # was misspelled ``init``, so ``history`` was never set;
                # a dict of per-epoch lists mirrors keras' ``History.history``
                self.history = losses.to_dict(orient='list')

        h = History()
    else:
        path = make_path()
        model = model_cls(
            model=MLP(units=99, num_layers=4,
                      activation='relu'),
            lr=0.006440897421063212,
            input_features=dataset.input_features,
            output_features=dataset.output_features,
            epochs=400, batch_size=48,
            verbosity=0
        )
        h = model.fit(X_train, y_train)

    if return_path:
        return model, path, h
    return model, h
def confidenc_interval(model, X_train, y_train, X_test, y_test, alpha,
                       n_splits=5):
    """Estimate prediction intervals on ``X_test`` from K-fold residuals.

    The model is first trained on the whole training set. It is then trained
    on each of ``n_splits`` folds of the training data. After each fold, the
    absolute out-of-fold residuals are collected and predictions on ``X_test``
    are recorded. The interval half-width ``ci`` is the ``1 - alpha``
    quantile of all absolute residuals. For each test sample:

    - ``upper`` is the ``1 - alpha`` quantile of the fold predictions plus ``ci``;
    - ``lower`` is the ``alpha`` quantile minus ``ci``.

    Parameters
    ----------
    model :
        model with ``fit(X, y, batch_size=..., verbose=..., validation_data=...,
        epochs=...)`` and ``predict(X)``. It is trained in place and the same
        object is reused for every fold, so each fold starts from the weights
        left by the previous training.
    X_train, y_train :
        training data; must support integer-array indexing (e.g. numpy arrays).
    X_test, y_test :
        test data; ``y_test`` is only used as validation data during ``fit``.
    alpha : float
        miscoverage level, e.g. 0.1 for a 90% interval.
    n_splits : int
        number of folds.

    Returns
    -------
    pd.DataFrame
        one row per test sample with columns ``pred`` (median of the fold
        predictions), ``upper`` and ``lower``.
    """
    model.fit(X_train, y_train, batch_size=48, verbose=0,
              validation_data=(X_test, y_test),
              epochs=400)

    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    abs_residuals = []
    fold_preds = []
    for train_index, test_index in kf.split(X_train):
        X_train_, X_test_ = X_train[train_index], X_train[test_index]
        y_train_, y_test_ = y_train[train_index], y_train[test_index]

        model.fit(X_train_, y_train_, validation_data=(X_test_, y_test_),
                  verbose=0, batch_size=48, epochs=400)

        # flatten both sides: a (m, 1) target minus a (m,) prediction would
        # broadcast to an (m, m) matrix
        _pred = np.asarray(model.predict(X_test_)).reshape(-1, )
        abs_residuals.append(np.abs(np.asarray(y_test_).reshape(-1, ) - _pred))

        # predict now: ``model`` is the same object in every fold, so keeping
        # it and predicting after the loop would give identical columns
        fold_preds.append(np.asarray(model.predict(X_test)).reshape(-1, ))

    y_pred_multi = np.column_stack(fold_preds)  # (n_test, n_splits)
    ci = np.quantile(np.concatenate(abs_residuals), 1 - alpha)

    df = pd.DataFrame()
    df['pred'] = np.median(y_pred_multi, axis=1)
    df['upper'] = np.quantile(y_pred_multi, 1 - alpha, axis=1) + ci
    df['lower'] = np.quantile(y_pred_multi, alpha, axis=1) - ci
    return df
def plot_ci(df, alpha):
    """Plot predictions together with their confidence band.

    Parameters
    ----------
    df : pd.DataFrame
        with columns ``pred``, ``upper`` and ``lower``, one row per sample.
    alpha : float
        miscoverage level; the legend shows ``int((1 - alpha) * 100)`` percent.

    Returns
    -------
    matplotlib Axes
    """
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.fill_between(np.arange(len(df)), df['upper'], df['lower'], alpha=0.5, color='C1')
    p1 = ax.plot(df['pred'], color="C1", label="Prediction")
    # empty patch used only to build a combined (band + line) legend entry;
    # ``np.NaN`` was removed in NumPy 2.0, ``np.nan`` works on all versions
    p2 = ax.fill(np.nan, np.nan, color="C1", alpha=0.5)
    percent = int((1 - alpha) * 100)
    ax.legend([(p2[0], p1[0]), ], [f'{percent}% Confidence Interval'],
              fontsize=12)
    ax.set_xlabel("Test Samples", fontsize=12)
    ax.set_ylabel("Adsorption Capacity", fontsize=12)
    return ax
def evaluate_model(true, predicted):
    """Print mse, rmse, r2, r2_score, mape and mae of ``predicted`` vs ``true``."""
    errors = RegressionMetrics(true, predicted)
    metric_names = ('mse', 'rmse', 'r2', 'r2_score', 'mape', 'mae')
    for name in metric_names:
        value = getattr(errors, name)()
        print(name, value)
    return None
def prediction_distribution(
        feature_name,
        test_p,
        cut,
        grid=None,
        plot_type="violin",
        show:bool = True,
):
    """Plot the distribution of test-set predictions over bins of one feature.

    Parameters
    ----------
    feature_name : str
        input feature whose value bins group the test samples.
    test_p :
        predictions for the test set of ``get_dataset()``, one per sample.
    cut :
        passed to ``easy_mpl.violin_plot`` (violin plots only).
    grid :
        custom grid points, passed as ``cust_grid_points`` to
        ``prediction_distribution_plot``.
    plot_type : str
        ``"bar"`` returns the plot made by ``prediction_distribution_plot``;
        ``"violin"`` draws violins; any other value draws box plots.
    show : bool
        whether to call ``plt.show()``.

    Returns
    -------
    matplotlib Axes
    """
    dataset, _, _ = get_dataset()
    X_test, _ = dataset.test_data()
    ax, df = prediction_distribution_plot(
        mode='regression',
        inputs=pd.DataFrame(X_test, columns=dataset.input_features),
        prediction=test_p,
        feature=feature_name,
        feature_name=feature_name,
        show=False,
        cust_grid_points=grid
    )

    if plot_type == "bar":
        if show:
            plt.show()
        return ax

    if feature_name == 'Volume (L)':
        # hand-picked bins for this feature: drop the row labelled 1 and
        # relabel the remaining five bins
        df.drop(1, inplace=True)
        df['display_column'] = ['[0.02,0.04)', '[0.04,0.05)', '[0.05,0.1)', '[0.1,0.25)', '[0.25,1)']

    # (feature value, prediction) pairs of the test set, built once instead
    # of once per interval
    feature_vs_pred = pd.DataFrame(X_test, columns=dataset.input_features)
    feature_vs_pred['target'] = test_p
    feature_vs_pred = feature_vs_pred[[feature_name, 'target']]
    feature_values = feature_vs_pred[feature_name]

    preds = {}
    for interval in df['display_column']:
        # labels presumably look like "[st,en)" (as the Volume labels above);
        # strip the brackets and treat every bin as half-open [st, en)
        st, en = interval.split(',')
        st = float(''.join(e for e in st if e not in ["]", ")", "[", "("]))
        en = float(''.join(e for e in en if e not in ["]", ")", "[", "("]))
        in_bin = (feature_values >= st) & (feature_values < en)
        preds[interval] = feature_vs_pred.loc[in_bin, 'target'].values

    for k, v in preds.items():
        assert len(v) > 0, f"{k} has no values in it"

    plt.close('all')
    if plot_type == "violin":
        ax = violin_plot(list(preds.values()), cut=cut, show=False)
        ax.set_xticks(range(len(preds)))
        ax.set_facecolor("#fbf9f4")
    else:
        ax, _ = boxplot(
            list(preds.values()), show=False,
            fill_color="lightpink", patch_artist=True,
            medianprops={"color": "black"}, flierprops={"ms": 1.0})

    ax.set_xticklabels(list(preds.keys()), size=12, weight='bold')

    # fix the visible y tick positions before relabelling them, otherwise the
    # integer labels can end up on other ticks if the locator recomputes them;
    # ticks outside the view are dropped so set_yticks does not expand it
    ymin, ymax = sorted(ax.get_ylim())
    tol = (ymax - ymin) * 1e-10
    yticks = np.asarray(ax.get_yticks())
    yticks = yticks[(yticks >= ymin - tol) & (yticks <= ymax + tol)]
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticks.astype(int), size=12, weight='bold')

    ax.set_title(feature_name, size=14, fontweight="bold")
    if show:
        plt.show()
    return ax
def box_violin(ax, data, palette=None):
    """Draw horizontal half violins with box plots on top of them.

    Each violin drawn here is clipped to half of its bounding box (from ``y0``
    to ``y0 + height / 2`` in data coordinates). A box plot with a transparent
    face is then drawn over it. The axis limits set by the violin plot are
    restored at the end.

    Parameters
    ----------
    ax : matplotlib Axes or None
        axes to draw on; the current axes is used if None.
    data :
        data accepted by ``seaborn.violinplot``/``seaborn.boxplot``
        (presumably wide form, one violin per column — confirm with callers).
    palette :
        seaborn palette; defaults to a cubehelix palette.
    """
    if palette is None:
        palette = sns.cubehelix_palette(start=.5, rot=-.5, dark=0.3, light=0.7)
    if ax is None:
        # seaborn would fall back to the current axes as well
        ax = plt.gca()

    # only clip the violins added below, not collections already on ``ax``
    n_existing = len(ax.collections)
    # NOTE(review): seaborn >= 0.13 deprecates ``scale`` in favour of
    # ``density_norm``; kept for compatibility with older versions.
    ax = sns.violinplot(orient='h', data=data,
                        palette=palette,
                        scale="width", inner=None,
                        ax=ax)
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    for violin in ax.collections[n_existing:]:
        bbox = violin.get_paths()[0].get_extents()
        x0, y0, width, height = bbox.bounds
        violin.set_clip_path(plt.Rectangle((x0, y0), width, height / 2, transform=ax.transData))

    sns.boxplot(orient='h', data=data, saturation=1, showfliers=False,
                width=0.3, boxprops={'zorder': 3, 'facecolor': 'none'}, ax=ax)
    # (the former loop shifting "dots" iterated over an empty slice because
    # no strip plot is drawn, so it has been removed)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return
def shap_interaction_all(shap_values_exp, feature, feature_names, CAT_FEATURES,
                         max_plots=10):
    """Show SHAP scatter plots of ``feature`` coloured by likely interaction partners.

    Candidate partners are ranked with ``shap.utils.potential_interactions``
    (strongest first). Only the first ``len(feature_names)`` candidates are
    considered, partners whose name is in ``CAT_FEATURES`` are skipped, and
    at most ``max_plots`` figures are shown.

    Parameters
    ----------
    shap_values_exp : shap.Explanation
        explanation for all features.
    feature :
        index (or name) of the feature to plot.
    feature_names : list
        names of the features; bounds the number of candidates examined.
    CAT_FEATURES : list
        names of features not used for colouring.
    max_plots : int
        maximum number of figures to show (default 10).
    """
    inds = shap.utils.potential_interactions(shap_values_exp[:, feature], shap_values_exp)

    n_plots = 0
    # the previous ``while n <= len(feature_names)`` loop read one index past
    # the end and raised IndexError when fewer than 10 plots were made
    for idx in inds[:len(feature_names)]:
        if shap_values_exp.feature_names[idx] in CAT_FEATURES:
            continue
        sh_scatter(shap_values_exp[:, feature], show=False,
                   color=shap_values_exp[:, idx],
                   )
        plt.tight_layout()
        plt.show()
        n_plots += 1
        if n_plots >= max_plots:
            break
    return
def shap_scatter(
        shap_values,  # SHAP values for a single feature
        feature_wrt:pd.Series = None,
        show_hist=True,
        show=True,
        is_categorical=False,
        palette_name = "tab10",
        s = 70,
        edgecolors='black',
        linewidth=0.8,
        alpha=0.8,
        ax = None,
        **scatter_kws
):
    """Scatter a feature's values against its SHAP values.

    Parameters
    ----------
    shap_values : shap.Explanation
        explanation of a single feature; its ``data``, ``values`` and
        ``feature_names`` (a str) are used.
    feature_wrt : pd.Series, optional
        second feature used to colour the points. Categorical values get one
        color per category and a legend; numeric values get a colorbar.
    show_hist : bool
        draw a faint histogram of the feature values on a twin y axis.
    show : bool
        call ``plt.show()`` at the end.
    is_categorical : bool
        whether ``feature_wrt`` is categorical.
    palette_name : str or list/tuple
        seaborn palette name, or explicit colors (one per unique category).
    s, edgecolors, linewidth, alpha, **scatter_kws
        forwarded to ``easy_mpl.scatter``.
    ax : matplotlib Axes, optional
        axes to draw on; a new figure is created if None.

    Returns
    -------
    matplotlib Axes
    """
    if ax is None:
        fig, ax = plt.subplots()

    c = None
    color_map = None
    if feature_wrt is not None:
        if is_categorical:
            categories = feature_wrt.unique()
            if isinstance(palette_name, (tuple, list)):
                assert len(palette_name) == len(categories)
                rgb_values = palette_name
            else:
                rgb_values = sns.color_palette(palette_name, len(categories))
            color_map = dict(zip(categories, rgb_values))
            c = feature_wrt.map(color_map)
        else:
            c = feature_wrt.values.reshape(-1,)

    _, pc = scatter(
        shap_values.data,
        shap_values.values,
        c=c,
        s=s,
        marker="o",
        edgecolors=edgecolors,
        linewidth=linewidth,
        alpha=alpha,
        ax=ax,
        show=False,
        **scatter_kws
    )

    if feature_wrt is not None:
        feature_wrt_name = ' '.join(feature_wrt.name.split('_'))
        if is_categorical:
            # one legend entry per category, in the colors used for the points
            handles = [Line2D([0], [0], marker='o', color='w', markerfacecolor=v,
                              label=k, markersize=8) for k, v in color_map.items()]
            ax.legend(title=feature_wrt_name,
                      handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left',
                      title_fontsize=14
                      )
        else:
            # attach the colorbar to ``ax`` in its own figure; ``plt.colorbar``
            # works on the current figure, which need not be ``ax``'s
            cbar = ax.figure.colorbar(pc, ax=ax, aspect=80)
            cbar.ax.set_ylabel(feature_wrt_name, rotation=90, labelpad=14,
                               fontsize=14, weight="bold")
            set_ticks(cbar.ax, "y")
            cbar.set_alpha(1)
            cbar.outline.set_visible(False)

    feature_name = ' '.join(shap_values.feature_names.split('_'))

    ax.set_xlabel(feature_name, fontsize=14, weight="bold")
    ax.set_ylabel(f"SHAP value for {feature_name}", fontsize=14, weight="bold")
    ax.axhline(0, color='grey', linewidth=1.3, alpha=0.3, linestyle='--')

    set_ticks(ax)
    set_ticks(ax, "y")

    if show_hist:
        x = shap_values.data

        # fewer bins for smaller samples
        if len(x) >= 500:
            bin_edges = 50
        elif len(x) >= 200:
            bin_edges = 20
        elif len(x) >= 100:
            bin_edges = 10
        else:
            bin_edges = 5

        ax2 = ax.twinx()

        xlim = ax.get_xlim()

        ax2.hist(x.reshape(-1,), bin_edges,
                 range=(xlim[0], xlim[1]),
                 density=False, facecolor='#000000', alpha=0.1, zorder=-1)
        # y limit = sample size keeps the histogram low in the plot
        ax2.set_ylim(0, len(x))
        ax2.set_yticks([])

    if show:
        plt.show()

    return ax
def set_ticks(axes:"plt.Axes", which="x", size=12):
    """Relabel the ticks of one axis with bold, rounded values.

    Float tick positions are shown rounded to 2 decimals; other tick values
    are shown as ints. The tick positions are fixed with ``set_{which}ticks``
    before the labels are set. matplotlib warns that labels set on
    non-fixed ticks may end up at unexpected positions, e.g. when the limits
    later change because a twin axis is added.

    Parameters
    ----------
    axes : plt.Axes
        axes whose ticks are relabelled.
    which : str
        ``"x"`` or ``"y"``.
    size : int
        font size of the labels.
    """
    ticks = np.array(getattr(axes, f"get_{which}ticks")())

    # the locator may return ticks just outside the view; fixing those would
    # make matplotlib expand the limits, so keep only the visible ones
    lo, hi = sorted(getattr(axes, f"get_{which}lim")())
    tol = (hi - lo) * 1e-10
    ticks = ticks[(ticks >= lo - tol) & (ticks <= hi + tol)]

    if 'float' in ticks.dtype.name:
        labels = np.round(ticks, 2)
    else:
        labels = ticks.astype(int)

    getattr(axes, f"set_{which}ticks")(ticks)
    getattr(axes, f"set_{which}ticklabels")(labels, size=size, weight="bold")
    return
def _jitter_data(data, x_jitter, seed=None):
s = np.random.RandomState(seed)
s.random_sample([1, 2, 3])
if x_jitter > 0:
if x_jitter > 1: x_jitter = 1
xvals = data.copy()
if isinstance(xvals[0], float):
xvals = xvals.astype(np.float)
xvals = xvals[~np.isnan(xvals)]
xvals = np.unique(xvals) # returns a sorted array
if len(xvals) >= 2:
smallest_diff = np.min(np.diff(xvals))
jitter_amount = x_jitter * smallest_diff
data += (s.random_sample(size = len(data))*jitter_amount) - (jitter_amount/2)
return data
def bar_chart(
        values,
        labels=None,
        orient:str = 'h',
        sort:bool = False,
        max_bars:int = None,
        errors = None,
        color=None,
        cmap: Union[str, List[str]] = None,
        rotation:int = 0,
        bar_labels: Union[list, np.ndarray] = None,
        bar_label_kws=None,
        share_axes: bool = True,
        width = None,
        ax:plt.Axes = None,
        ax_kws: dict = None,
        show:bool = True,
        **kwargs
) -> Union[plt.Axes, List[plt.Axes]]:
    """Draw one or several bar charts.

    Parameters
    ----------
    values : array-like
        1-D for a single chart, or 2-D of shape (n_bars, n_charts).
    labels : array-like, optional
        bar labels. If None and ``values`` has ``index`` and ``name``
        attributes (e.g. a ``pd.Series``), its index is used; otherwise the
        bars are labelled ``F0, F1, ...``.
    orient : str
        ``'h'``/``'horizontal'`` for horizontal bars, anything else vertical.
    sort : bool
        sort the bars of each chart by value.
    max_bars : int, optional
        keep the last ``max_bars`` bars and merge the rest into one bar.
    errors : array-like, optional
        error bar sizes.
    color, cmap :
        bar colors and/or colormap(s), resolved by ``get_color``.
    rotation : int
        rotation of the tick labels.
    bar_labels : array-like, optional
        text drawn on the bars.
    bar_label_kws : dict, optional
        keyword arguments for ``Axes.bar_label``.
    share_axes : bool
        for 2-D ``values``: draw all charts grouped on one axes (True) or each
        chart on its own axes (False).
    width : float, optional
        bar width.
    ax : plt.Axes, optional
        axes to draw on.
    ax_kws : dict, optional
        keyword arguments for ``easy_mpl.utils.process_axes``.
    show : bool
        call ``plt.show()``.
    **kwargs
        forwarded to ``Axes.bar``/``Axes.barh``; ``figsize`` is used to size
        the figure instead.

    Returns
    -------
    a single Axes, or the axes created when there are several.
    """
    if labels is None:
        # previously ``hasattr("values", "name")`` tested a string literal and
        # was always False, so a Series' index was never used as labels
        if hasattr(values, "index") and hasattr(values, "name"):
            labels = values.index

    naxes = 1
    ncharts = 1
    if is_1d(values):
        values = to_1d_array(values)
    else:
        values = np.array(values)
        ncharts = values.shape[1]
        if share_axes:
            kwargs['edgecolor'] = kwargs.get('edgecolor', 'k')
        else:
            naxes = values.shape[1]

    colors = get_color(cmap, color, ncharts, len(values))

    figsize = kwargs.pop('figsize', None)

    ax = maybe_create_axes(ax, naxes, figsize=figsize)

    if ncharts == 1:
        values, labels, bar_labels, colors = preprocess(values, labels,
                                                        bar_labels, sort, max_bars, colors[0])
        ind = np.arange(len(values))
        bar_on_axes(ax[0], orient=orient, ax_kws=ax_kws, ind=ind,
                    values=values,
                    width=width, ticks=ind, labels=labels, color=colors,
                    bar_labels=bar_labels,
                    rotation=rotation, errors=errors,
                    bar_label_kws=bar_label_kws, kwargs=kwargs)

    elif share_axes:
        ind = np.arange(len(values))  # the label locations
        width = width or 1/ncharts * 0.9  # the width of the bars
        inds = []
        for idx in range(ncharts):
            if idx > 0:
                ind = ind + width
            inds.append(ind)
        inds = np.column_stack(inds)
        ticks = np.mean(inds, axis=1)

        for idx in range(ncharts):
            _kwargs = kwargs.copy()
            _kwargs['label'] = _kwargs.get('label', idx)
            # preprocess every chart from the caller's labels/bar_labels; they
            # used to be overwritten, so chart ``idx`` was sorted/truncated on
            # top of chart ``idx - 1``'s result
            vals, chart_labels, chart_bar_labels, chart_color = preprocess(
                values[:, idx], labels, bar_labels, sort, max_bars, colors[idx])
            bar_on_axes(ax[0], orient, ax_kws,
                        inds[:, idx], vals, width, ticks, chart_labels,
                        chart_color, chart_bar_labels,
                        rotation, errors, bar_label_kws, _kwargs)
    else:
        for idx in range(naxes):
            axes = ax[idx]
            data, chart_labels, chart_bar_labels, chart_color = preprocess(
                values[:, idx], labels, bar_labels, sort, max_bars, colors[idx])
            _kwargs = kwargs.copy()
            _kwargs['label'] = _kwargs.get('label', idx)
            ind = np.arange(len(data))
            bar_on_axes(axes, orient, ax_kws,
                        ind, data, width, ind, chart_labels,
                        chart_color, chart_bar_labels,
                        rotation, errors,
                        bar_label_kws=bar_label_kws, kwargs=_kwargs)

    if show:
        plt.show()

    if len(ax) == 1:
        ax = ax[0]
    return ax
def maybe_create_axes(ax, naxes:int, figsize=None)->List[plt.Axes]:
    """Return the axes to draw on as a flat sequence.

    If ``ax`` is None, the current axes (``plt.gca()``) is the starting point.
    For ``naxes > 1`` subplots are created with ``create_subplots`` and
    flattened. Otherwise the single axes is returned in a one-element list,
    after resizing its figure to ``figsize`` if given.
    """
    axes = plt.gca() if ax is None else ax

    if naxes <= 1:
        if figsize:
            axes.figure.set_size_inches(figsize)
        return [axes]

    _, axes = create_subplots(ax=axes, naxes=naxes, figsize=figsize)
    return axes.flatten()
def handle_sort(sort, values, labels, bar_labels, color):
    """Optionally sort the bars by ascending value.

    When ``sort`` is truthy, ``values`` is sorted with ``np.argsort``. The same
    permutation is applied to ``labels``, to ``bar_labels`` (float bar labels
    are also rounded to 2 decimals), and to ``color`` when it is a sequence of
    per-bar colors. Otherwise everything is returned unchanged.

    Returns
    -------
    tuple
        ``(values, labels, bar_labels, color)``
    """
    if not sort:
        return values, labels, bar_labels, color

    order = np.argsort(values)
    values = values[order]
    labels = np.array(labels)[order]

    if bar_labels is not None:
        bar_labels = np.array(bar_labels)[order]
        if 'float' in bar_labels.dtype.name:
            bar_labels = np.round(bar_labels, decimals=2)

    per_bar_colors = (isinstance(color, (list, np.ndarray, tuple))
                      and (is_rgb(color[0]) or isinstance(color[0], str)))
    if per_bar_colors:
        color = np.array(color)[order]

    return values, labels, bar_labels, color
def handle_maxbars(max_bars, values, labels):
    """Keep the last ``max_bars`` bars and merge the others into one bar.

    The merged bar holds the sum of the dropped values. It is labelled
    ``"Rest of <n>"`` and placed first. Nothing is done when ``max_bars`` is
    falsy or when there are at most ``max_bars`` bars; previously this added
    a spurious ``"Rest of 0"`` bar (or one with a negative count).

    Returns
    -------
    tuple
        ``(values, labels)``
    """
    if not max_bars or len(values) <= max_bars:
        return values, labels

    n_rest = len(values) - max_bars
    rest_total = sum(values[:n_rest])
    values = np.append(rest_total, values[n_rest:])
    labels = np.append(f"Rest of {n_rest}", labels[n_rest:])
    return values, labels
def preprocess(values, labels, bar_labels, sort, max_bars, colors):
    """Prepare the data of one chart.

    Missing ``labels`` become ``F0, F1, ...``. Then ``handle_sort`` and
    ``handle_maxbars`` are applied, in that order.

    Returns
    -------
    tuple
        ``(values, labels, bar_labels, colors)``
    """
    labels = [f"F{i}" for i in range(len(values))] if labels is None else labels
    values, labels, bar_labels, colors = handle_sort(sort, values, labels, bar_labels, colors)
    values, labels = handle_maxbars(max_bars, values, labels)
    return values, labels, bar_labels, colors
def bar_on_axes(ax, orient, ax_kws, *args, **kwargs):
    """Draw bars on ``ax`` and then apply ``ax_kws`` with ``process_axes``.

    ``orient`` of ``'h'`` or ``'horizontal'`` selects ``horizontal_bar``; any
    other value selects ``vertical_bar``. The remaining arguments are
    forwarded unchanged.
    """
    draw = horizontal_bar if orient in ('h', 'horizontal') else vertical_bar
    draw(ax, *args, **kwargs)

    if ax_kws:
        process_axes(ax, **ax_kws)
    return
def horizontal_bar(ax, ind, values, width, ticks, labels, color, bar_labels,
                   rotation, errors, bar_label_kws, kwargs):
    """Draw horizontal bars at y positions ``ind`` on ``ax``.

    A truthy ``width`` is passed as the bar height. Tick labels are placed at
    ``ticks``. Bar labels and error bars are added by ``set_bar_labels``, and
    a legend is drawn when ``kwargs`` contains ``label``.
    """
    bar_args = (ind, values, width) if width else (ind, values)
    bars = ax.barh(*bar_args, color=color, **kwargs)

    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, rotation=rotation)

    set_bar_labels(bars, ax, bar_labels, bar_label_kws, errors,
                   values, ind)

    if 'label' in kwargs:
        ax.legend()
    return
def vertical_bar(ax, ind, values, width, ticks, labels, color, bar_labels,
                 rotation, errors, bar_label_kws, kwargs):
    """Draw vertical bars at x positions ``ind`` on ``ax``.

    The bar width defaults to 0.8 when ``width`` is falsy. Tick labels are
    placed at ``ticks``. ``errors`` are drawn as vertical error bars, and a
    legend is drawn when ``kwargs`` contains ``label`` (as ``horizontal_bar``
    does).
    """
    bar = ax.bar(ind, values, width=width or 0.8, color=color, **kwargs)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=rotation)
    # ``set_bar_labels`` always draws errors with ``xerr``, which is only
    # right for horizontal bars; for vertical bars the error is along y
    set_bar_labels(bar, ax, bar_labels, bar_label_kws, None,
                   ind, values)
    if errors is not None:
        ax.errorbar(ind, values, yerr=errors, fmt=".", color="black")
    if 'label' in kwargs:
        ax.legend()
    return
def set_bar_labels(bar, ax, bar_labels, bar_label_kws, errors,
                   values, ind):
    """Annotate bars with text and draw horizontal error bars.

    ``bar_labels`` are drawn with ``Axes.bar_label`` (default
    ``label_type='center'``). If ``ax`` has no ``bar_label`` method, the labels
    are passed to ``bar.set_label`` instead. ``errors`` are drawn with
    ``ax.errorbar(values, ind, xerr=errors)``.
    """
    if bar_labels is not None:
        label_kws = bar_label_kws if bar_label_kws else {'label_type': 'center'}
        if not hasattr(ax, 'bar_label'):
            bar.set_label(bar_labels)
        else:
            ax.bar_label(bar, labels=bar_labels, **label_kws)

    if errors is not None:
        ax.errorbar(values, ind, xerr=errors, fmt=".", color="black")
    return
def is_1d(array):
    """Return True when ``array`` holds as many elements as its first dimension.

    Dict key/value views are turned into lists first so numpy can convert them.
    """
    if isinstance(array, (KeysView, ValuesView)):
        array = list(array)
    arr = np.array(array)
    return len(arr) == arr.size
def get_color(cmap, color, ncharts, n_bars)->list:
    """Return one color specification per chart.

    ``cmap`` and ``color`` may each be a single value (used for every chart)
    or a list with one entry per chart. For a single chart, a list ``color``
    is treated as per-bar colors. A chart whose color is None gets colors made
    by ``make_cols_from_cmap(cmap, n_bars, 0.2)`` from its colormap, or from a
    random one of ``BAR_CMAPS`` when no colormap is given.
    """
    cmaps = cmap if isinstance(cmap, list) else [cmap] * ncharts

    if not isinstance(color, list):
        chart_colors = [color] * ncharts
    elif ncharts == 1:
        # a list for a single chart means one color per bar: wrap it so the
        # whole list (not just its first element) is used for chart 0
        chart_colors = [color]
    else:
        chart_colors = color

    result = []
    for idx in range(ncharts):
        from_cmap = make_cols_from_cmap(cmaps[idx] or random.choice(BAR_CMAPS), n_bars, 0.2)
        result.append(from_cmap if chart_colors[idx] is None else chart_colors[idx])
    return result
Total running time of the script: (0 minutes 0.025 seconds)