MeLIME explanations for a Convolutional Neural Network (CNN) model. The model was trained on the MNIST dataset.
import os, sys
sys.path.append('..')
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
from melime.generators.vae_gen import VAEGen
from melime.explainers.explainer import Explainer
from melime.explainers.visualizations.visualization import ImagePlot
import torch
import torch.utils.data
from torchvision import datasets, transforms
from utils import mnist_cnn
# Data loading and CNN setup.
batch_size = 128
# Pick the GPU only when one is actually available; a hard-coded True made the
# script crash on CPU-only machines.
cuda = torch.cuda.is_available()
# Pinned memory and a loader worker are only requested when a GPU is used.
kwargs = {"num_workers": 1, "pin_memory": True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/", train=True, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]),
    ),
    batch_size=batch_size, shuffle=True, **kwargs,
)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/", train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
        ]),
    ),
    batch_size=batch_size, **kwargs,
)

os.makedirs("pretrained", exist_ok=True)
path_model = "pretrained/cnn_mnist.torch"
device = torch.device("cuda" if cuda else "cpu")
device_cpu = torch.device("cpu")
# Reuse a cached CNN when present; otherwise train one (train_model receives
# path_model and presumably saves the result there — confirm in utils.mnist_cnn).
if os.path.exists(path_model):
    model = mnist_cnn.model_load(device, path=path_model)
else:
    # Train on the device selected above ("cuda" or "cpu" as a string, the
    # same form the original hard-coded 'cuda' used) instead of always 'cuda'.
    model = mnist_cnn.train_model(
        train_loader, test_loader, device=device.type, path=path_model, epochs=20)
def model_predict(x_):
    """Return the CNN's class scores for one or more MNIST images.

    This is the prediction function handed to MeLIME.

    Parameters
    ----------
    x_ : torch.Tensor or array-like
        Image data whose total size is a multiple of 784; it is reshaped to
        ``(-1, 1, 28, 28)`` before being fed to the model.

    Returns
    -------
    numpy.ndarray
        ``exp`` of the model output, one row per image, copied to the CPU.
    """
    # as_tensor shares memory with float32 NumPy input (like from_numpy) but
    # also accepts float64 arrays and lists; the CNN is assumed to use float32
    # weights, which reject float64 input — TODO confirm in utils.mnist_cnn.
    x_tf = torch.as_tensor(x_, dtype=torch.float32)
    with torch.no_grad():
        x_tf = x_tf.reshape(-1, 1, 28, 28)
        y = model(x_tf.to(device=device))
        # Presumably the model emits log-probabilities (e.g. log_softmax), so
        # exp() yields probabilities — TODO confirm. `.data` is unnecessary
        # under no_grad().
        y = y.exp().to(device_cpu).numpy()
    return y
# MeLIME generator (VAEGen over 784-pixel inputs) used by the Explainer below.
generator = VAEGen(input_dim=784, verbose=True)
path_vaegen = "pretrained/vae_gen.melime"
if not os.path.exists(path_vaegen):
    # Nothing cached yet: fit on the training loader, then cache the result.
    generator.fit(train_loader, epochs=20)
    generator.save_manifold(path_vaegen)
else:
    generator.load_manifold(path_vaegen)

# First batch of the (unshuffled) test loader; examples below index into it.
batch_test = next(iter(test_loader))
def get_explanation(x_explain, class_to_explain, r=0.6, n_samples=1000,
                    tol_importance=0.01, tol_error=0.01):
    """Run MeLIME on ``x_explain`` for the class ``class_to_explain``.

    Prints the model's probability for the class, then builds a fresh
    ``Explainer`` (SGD local model, module-level ``generator``) and calls
    ``explain_instance``.

    Parameters
    ----------
    x_explain : array-like
        The instance to explain (one MNIST image).
    class_to_explain : int
        Index of the class whose prediction is explained.
    r : float, default 0.6
        Forwarded as ``r`` to ``explain_instance`` (presumably the sampling
        radius used by the generator — confirm in melime).
    n_samples : int, default 1000
        Forwarded to ``explain_instance``.
    tol_importance : float, default 0.01
        Forwarded to ``explain_instance``.
    tol_error : float, default 0.01
        Forwarded to ``explain_instance``.

    Returns
    -------
    tuple
        ``(explanation, contra)`` exactly as returned by ``explain_instance``.
    """
    print("Explanation for: ", class_to_explain)
    print("Probability:", model_predict(x_explain)[0][class_to_explain])
    explain_linear = Explainer(
        model_predict=model_predict, generator=generator, local_model='SGD')
    explanation, contra = explain_linear.explain_instance(
        x_explain=x_explain,
        class_index=class_to_explain,
        r=r,
        n_samples=n_samples,
        tol_importance=tol_importance,
        tol_error=tol_error,
        include_x_explain_train=True,
    )
    return explanation, contra
def plot_explanation(explanation, contra, x_explain=None, class_to_explain=0):
    """Plot an explanation's importances together with ``contra``.

    Parameters
    ----------
    explanation : object
        Result of ``get_explanation``; its ``explain()`` output is plotted.
    contra : object
        Second value returned by ``get_explanation``.
    x_explain : array-like, optional
        Forwarded to the plot; the script saves these figures as "*_mask".
    class_to_explain : int, default 0
        Class label passed to the plot.

    Returns
    -------
    Whatever ``ImagePlot.plot_importance_contrafactual`` returns.
    """
    importances = explanation.explain()
    result = ImagePlot.plot_importance_contrafactual(
        importances, contra, class_to_explain, x_explain=x_explain)
    return result
# Example: image 9 of the first test batch (output files call it "number_9").
x_explain = batch_test[0][9].to(device_cpu).numpy()
y_explain = model_predict(x_explain)
y_explain_index = np.argmax(y_explain)
print('Predicted:', y_explain_index, y_explain[0][y_explain_index])
# Class indices sorted from highest to lowest score.
top_class = np.argsort(y_explain)[0][::-1]
print('Top 3 predicted Class', top_class[:3])
y = f'Prediction: {np.argmax(y_explain):}'
ImagePlot.plot_instances(x_explain.reshape(28,28))
# For each of the three top classes: explain with r=1.0, then save the plot
# without x_explain and the "_mask" plot with x_explain.
for index in range(3):
    explanation, contra = get_explanation(x_explain, top_class[index], r=1.0)
    plot_explanation(explanation, contra, x_explain=None, class_to_explain=top_class[index])
    plt.savefig(f"MNIST_number_9_e_c_{top_class[index]}.svg", dpi=300)
    plot_explanation(explanation, contra, x_explain, top_class[index])
    plt.savefig(f"MNIST_number_9_e_c_{top_class[index]}_mask.svg", dpi=300)
# Example: image 15 of the first test batch (output files call it "number_5").
x_explain = batch_test[0][15].to(device_cpu).numpy()
y_explain = model_predict(x_explain)
y_explain_index = np.argmax(y_explain)
# Class indices sorted from highest to lowest score.
top_class = np.argsort(y_explain)[0][::-1]
print('Predicted:', y_explain_index, y_explain[0][y_explain_index])
print('Top 3 predicted Class', top_class[:3])
y = f'Prediction: {y_explain_index}'
ImagePlot.plot_instances(x_explain.reshape(28,28))
for index in range(3):
    explanation, contra = get_explanation(x_explain, top_class[index], r=1.0)
    plot_explanation(explanation, contra, x_explain=None, class_to_explain=top_class[index])
    if index == 0:
        # The top class is saved under its own file name, with no mask figure.
        plt.savefig("MNIST_number_5.svg", dpi=300)
        continue
    plt.savefig(f"MNIST_number_5_e_c_{top_class[index]}.svg", dpi=300)
    # The mask figure is drawn and saved once per class (it used to be
    # rendered twice and written twice to the same file).
    plot_explanation(explanation, contra, x_explain, top_class[index])
    plt.savefig(f"MNIST_number_5_e_c_{top_class[index]}_mask.svg", dpi=300)
# Example: image read from the text file five_example_1.np.
raw = np.loadtxt('five_example_1.np')
x_explain = raw.astype(np.float32).reshape(1, 28, 28)
del raw
y_explain = model_predict(x_explain)
y_explain_index = np.argmax(y_explain)
print('Predicted:', y_explain_index, y_explain[0][y_explain_index])
# Class indices sorted from highest to lowest score.
top_class = np.argsort(y_explain)[0][::-1]
print('Top 3 predicted Class', top_class[:3])
y = f'Prediction: {np.argmax(y_explain):}'
ImagePlot.plot_instances(x_explain.reshape(28,28))
plt.savefig(f"MNIST_number_5_file.svg", dpi=300)
# Top three classes, explained with get_explanation's default r (0.6).
for index in range(3):
    explanation, contra = get_explanation(x_explain, top_class[index])
    plot_explanation(explanation, contra, x_explain=None, class_to_explain=top_class[index])
    plt.savefig(f"MNIST_number_5_file_e_c_{top_class[index]}.svg", dpi=300)
    plot_explanation(explanation, contra, x_explain, top_class[index])
    plt.savefig(f"MNIST_number_5_file_e_c_{top_class[index]}_mask.svg", dpi=300)
# Example: image read from five_example_2.np (figures are shown, not saved).
x_explain = np.loadtxt('five_example_2.np').astype(np.float32).reshape(1, 28, 28)
y_explain = model_predict(x_explain)
y_explain_index = np.argmax(y_explain)
# Class indices sorted from highest to lowest score.
top_class = np.argsort(y_explain)[0][::-1]
print('predicted:', y_explain_index, y_explain[0][y_explain_index])
print('Top 3 predicted Class', top_class[:3])
y = f'Prediction: {np.argmax(y_explain)}'
for index in range(3):
    explanation, contra = get_explanation(x_explain, top_class[index])
    # The top class only gets the plot with x_explain; the others get both.
    if index != 0:
        plot_explanation(explanation, contra, x_explain=None, class_to_explain=top_class[index])
    plot_explanation(explanation, contra, x_explain, top_class[index])
# Example: image read from three_example.np (figures are shown, not saved).
x_explain = np.loadtxt('three_example.np').astype(np.float32).reshape(1, 28, 28)
y_explain = model_predict(x_explain)
y_explain_index = np.argmax(y_explain)
print('predicted:', y_explain_index, y_explain[0][y_explain_index])
# Class indices sorted from highest to lowest score.
top_class = np.argsort(y_explain)[0][::-1]
print('Top 3 predicted Class', top_class[:3])
y = f'Prediction: {np.argmax(y_explain):}'
# Unlike the other examples, the four highest-scoring classes are explained.
for index in range(4):
    explanation, contra = get_explanation(x_explain, top_class[index])
    plot_explanation(explanation, contra, x_explain=None, class_to_explain=top_class[index])
    plot_explanation(explanation, contra, x_explain, top_class[index])
Thank you!