#!/usr/bin/python
import re
import ROOT as root
import matplotlib.pyplot as mpl
import numpy
###########################################
# User configuration
#
# ROOT file to read; it must contain "TestTree" and "TrainTree".
filename = "TMVAReg.root"

# Name of the branch holding the regression target
# (looked up in both TestTree and TrainTree).
targetVariableName = "fvalue"

# Branches that are regression outputs: these are checked for
# overtraining and are left out of the input-variable comparison plots.
MVAVariableNames = [
    "FDA_GA",
    "KNN",
    "LD",
    "PDEFoam",
]

# Branches that are skipped entirely (no plots, no overtraining check).
dontPlotVariables = [
    "className",
    "classID",
]

# Human-readable name of the target, used in the x-axis labels
# of the input-variable plots.
targetVariableTitle = "Regression Target"

# Selection string passed to TTree.Draw before computing statistics.
cuts = "var1 > 0.5"
###########################################
# Open the input file and fail early with a readable message instead of
# letting a null tree blow up later inside the branch loop.
f1 = root.TFile(filename)
if not f1 or f1.IsZombie():
    raise IOError("Could not open input file: " + filename)

# Batch mode: histograms are only used for their statistics, so no
# canvases need to pop up.  The statistics box is switched off as well.
root.gROOT.SetBatch(1)
root.gStyle.SetOptStat(0)

tree = f1.Get("TestTree")
treeTrain = f1.Get("TrainTree")
if not tree or not treeTrain:
    raise IOError("TestTree and/or TrainTree not found in " + filename)

tmvaTools = root.TMVA.Tools.Instance()

# Per-input-variable statistics (parallel lists, labelled by nameList).
mutualInfoList = []
corrRatioList = []
corrList = []
nameList = []

# Per-regression-output overtraining measures (parallel lists, labelled
# by mutualInfoOvertrainNameList).
mutualInfoOvertrainList = []
mutualInfoOvertrainNameList = []
meanOvertrainList = []
rmsOvertrainList = []
# Walk the branches of TestTree and TrainTree in lockstep.  For every
# branch, the relation between the branch and the regression target is
# measured on the test sample.  Branches listed in MVAVariableNames are
# then compared between the train and test samples to estimate
# overtraining; all other branches feed the input-variable lists.
for branch, branchTrain in zip(tree.GetListOfBranches(), treeTrain.GetListOfBranches()):
    branchName = branch.GetName()
    if branchName != branchTrain.GetName():
        print("Warning: different branches from TestTree and TrainTree")
        continue
    if branchName in dontPlotVariables:
        continue

    # 2D histogram (target vs. branch) on the test sample, after cuts.
    histName = branchName + "HistForCorr"
    drawString = targetVariableName + ":" + branchName + ">>" + histName
    tree.Draw(drawString, cuts)
    myHist = root.gDirectory.Get(histName)
    corrRatio = tmvaTools.GetCorrelationRatio(myHist)
    mutualInfo = tmvaTools.GetMutualInformation(myHist)
    corr = myHist.GetCorrelationFactor()

    if branchName in MVAVariableNames:
        # Same 2D histogram on the training sample.
        histNameTrain = branchName + "HistForCorrTrain"
        drawStringTrain = targetVariableName + ":" + branchName + ">>" + histNameTrain
        treeTrain.Draw(drawStringTrain, cuts)
        myHistTrain = root.gDirectory.Get(histNameTrain)
        mutualInfoTrain = tmvaTools.GetMutualInformation(myHistTrain)

        # Relative excess of mutual information on the training sample.
        # The ratio divides by the *test* value, so that value must be
        # non-zero too; otherwise fall back to the large sentinel.
        if mutualInfoTrain > 0.0 and mutualInfo > 0.0:
            mutualInfoOvertrainList.append(mutualInfoTrain / mutualInfo - 1.0)
        else:
            mutualInfoOvertrainList.append(10000000)
        mutualInfoOvertrainNameList.append(branchName)

        ## Getting Offset, RMS
        # 1D histograms of (output - target); the same cuts are applied
        # as for the mutual-information histograms so that all
        # overtraining measures are computed on the same selection.
        histNameTrain = branchName + "HistForRMSTrain"
        drawStringTrain = branchName + "-" + targetVariableName + ">>" + histNameTrain
        treeTrain.Draw(drawStringTrain, cuts)
        myHistTrain = root.gDirectory.Get(histNameTrain)
        meanTrain = myHistTrain.GetMean()
        rmsTrain = myHistTrain.GetRMS()

        histName = branchName + "HistForRMS"
        drawString = branchName + "-" + targetVariableName + ">>" + histName
        tree.Draw(drawString, cuts)
        myHist = root.gDirectory.Get(histName)
        mean = myHist.GetMean()
        rms = myHist.GetRMS()

        # Mean offset: train / test.  RMS: test / train - 1.
        meanOvertrainList.append(meanTrain / mean)
        rmsOvertrainList.append(rms / rmsTrain - 1.0)
    else:
        # Ordinary input variable: keep its relation to the target.
        mutualInfoList.append(mutualInfo)
        corrRatioList.append(corrRatio)
        corrList.append(corr)
        nameList.append(branchName)
##########################
## Now make plots
def _saveBarChart(fig, ax, values, labels, xlabel, outputName, log=False):
    """Draw one horizontal bar chart on ``ax`` and save ``fig``.

    The axes are cleared first, so the same figure can be reused for
    every plot.

    values     -- bar lengths, one per label
    labels     -- y tick labels, same order as ``values``
    xlabel     -- x-axis label
    outputName -- file name handed to ``fig.savefig``
    log        -- draw the x axis on a log scale when True
    """
    ax.cla()
    ax.grid(axis="x")
    pos = numpy.arange(len(values))
    # Centre the bars on the tick positions explicitly so the labels line
    # up with the bars regardless of the matplotlib default alignment.
    ax.set_yticks(pos)
    ax.set_yticklabels(tuple(labels))
    ax.set_xlabel(xlabel)
    ax.barh(pos, values, 0.5, align="center", log=log)
    fig.savefig(outputName)


fig = mpl.figure()
ax1 = fig.add_subplot(111)
#ax1.set_position([0.25,0.1,0.7,0.85]) #uncomment if you need more space for names

# Input variables versus the regression target.
_saveBarChart(fig, ax1, mutualInfoList, nameList,
              "Mutual Information with " + targetVariableTitle,
              "mutualInfo.png")
_saveBarChart(fig, ax1, mutualInfoList, nameList,
              "Mutual Information with " + targetVariableTitle,
              "mutualInfoLog.png", log=True)
_saveBarChart(fig, ax1, corrRatioList, nameList,
              "Correlation Ratio with " + targetVariableTitle,
              "correlationRatio.png")
_saveBarChart(fig, ax1, corrList, nameList,
              "Correlation with " + targetVariableTitle,
              "correlation.png")

# Overtraining measures for the regression outputs.
_saveBarChart(fig, ax1, mutualInfoOvertrainList, mutualInfoOvertrainNameList,
              "Amount of Over Training ($I_{train}/I_{test}-1$)",
              "overtrainMI.png")
_saveBarChart(fig, ax1, rmsOvertrainList, mutualInfoOvertrainNameList,
              "Amount of Over Training ($RMS_{test}/RMS_{train}-1$)",
              "overtrainRMS.png")
# The plotted values are the train mean offset divided by the test mean
# offset, so the label names the ratio in that order.
_saveBarChart(fig, ax1, meanOvertrainList, mutualInfoOvertrainNameList,
              "Amount of Over Training ($<Offset>_{train}/<Offset>_{test}$)",
              "overtrainMean.png")