# Hugonweb | TMVA Plotter

#!/usr/bin/python

import re
import ROOT as root
import matplotlib.pyplot as mpl
import numpy

###########################################

# Input ROOT file to analyse.
filename = "TMVAReg.root"

# Name of the branch holding the regression target.  The same branch
# name is used for both TestTree and TrainTree.
targetVariableName = "fvalue"

# Branches that are tested for overtraining.  They are left out of the
# comparison plots of the input variables.
MVAVariableNames = [
    "FDA_GA",
    "KNN",
    "LD",
    "PDEFoam",
]

# Branches that are skipped in the comparison of input variables.
dontPlotVariables = [
    "className",
    "classID",
]

# Title of the target variable, appended to the x-axis label of the
# input-variable plots.
targetVariableTitle = "Regression Target"

# Selection (cut) string applied before the analysis.
cuts = "var1 > 0.5"

###########################################

# Value stored in an overtraining list when the corresponding ratio
# cannot be computed (zero denominator / empty histogram).  It is
# deliberately huge so the affected entry stands out in the plots.
undefinedRatioValue = 10000000


def _fillHistogram(sourceTree, expression, histName, selection):
    """Run TTree.Draw into a named histogram and return that histogram.

    sourceTree -- tree to draw from (TestTree or TrainTree)
    expression -- TTree.Draw variable expression, e.g. "y:x" or "a-b"
    histName   -- name of the histogram the draw output is redirected to
    selection  -- selection (cut) string passed to TTree.Draw
    """
    sourceTree.Draw(expression + ">>" + histName, selection)
    return root.gDirectory.Get(histName)


# Open the input file and switch ROOT to batch mode with the statistics
# box turned off.
f1 = root.TFile(filename)
if not f1 or f1.IsZombie():
    raise SystemExit("Error: could not open input file " + filename)
root.gROOT.SetBatch(1)
root.gStyle.SetOptStat(0)

tree = f1.Get("TestTree")
treeTrain = f1.Get("TrainTree")
if not tree or not treeTrain:
    raise SystemExit("Error: TestTree and/or TrainTree not found in " + filename)
tmvaTools = root.TMVA.Tools.Instance()

# Results for the input variables (one entry per branch in nameList).
mutualInfoList = []
corrRatioList = []
corrList = []
nameList = []

# Results for the MVA outputs (one entry per branch in
# mutualInfoOvertrainNameList).
mutualInfoOvertrainList = []
mutualInfoOvertrainNameList = []
meanOvertrainList = []
rmsOvertrainList = []

# Walk the branches of both trees in lock-step; pairs whose names do not
# match are reported and skipped.
for branch, branchTrain in zip(tree.GetListOfBranches(), treeTrain.GetListOfBranches()):
    branchName = branch.GetName()
    if branchName != branchTrain.GetName():
        print("Warning: different branches from TestTree and TrainTree")
        continue
    if branchName in dontPlotVariables:
        continue

    # Target versus this branch on the test sample.
    myHist = _fillHistogram(tree, targetVariableName + ":" + branchName,
                            branchName + "HistForCorr", cuts)
    corrRatio = tmvaTools.GetCorrelationRatio(myHist)
    mutualInfo = tmvaTools.GetMutualInformation(myHist)
    corr = myHist.GetCorrelationFactor()

    if branchName not in MVAVariableNames:
        # Ordinary input variable: keep its relation to the target.
        mutualInfoList.append(mutualInfo)
        corrRatioList.append(corrRatio)
        corrList.append(corr)
        nameList.append(branchName)
        continue

    # MVA output: compare mutual information on training and test samples.
    myHistTrain = _fillHistogram(treeTrain, targetVariableName + ":" + branchName,
                                 branchName + "HistForCorrTrain", cuts)
    mutualInfoTrain = tmvaTools.GetMutualInformation(myHistTrain)
    # The test-sample value is the denominator, so it must be checked as
    # well; otherwise a zero value raises ZeroDivisionError.
    if mutualInfoTrain > 0.0 and mutualInfo > 0.0:
        mutualInfoOvertrainList.append(mutualInfoTrain / mutualInfo - 1.0)
    else:
        mutualInfoOvertrainList.append(undefinedRatioValue)
    mutualInfoOvertrainNameList.append(branchName)

    ## Getting Offset, RMS
    # (MVA output - target) on each sample.  The same cuts are applied
    # here as for the histograms above so that every quantity is computed
    # on the same selection.
    offsetExpression = branchName + "-" + targetVariableName
    myHistTrain = _fillHistogram(treeTrain, offsetExpression,
                                 branchName + "HistForRMSTrain", cuts)
    meanTrain = myHistTrain.GetMean()
    rmsTrain = myHistTrain.GetRMS()
    myHist = _fillHistogram(tree, offsetExpression,
                            branchName + "HistForRMS", cuts)
    mean = myHist.GetMean()
    rms = myHist.GetRMS()

    # Mean offset is stored as train/test.
    if mean != 0.0:
        meanOvertrainList.append(meanTrain / mean)
    else:
        meanOvertrainList.append(undefinedRatioValue)
    # RMS is stored as test/train - 1.
    if rmsTrain > 0.0:
        rmsOvertrainList.append(rms / rmsTrain - 1.0)
    else:
        rmsOvertrainList.append(undefinedRatioValue)

##########################
## Now make plots

def _saveBarChart(fig, ax, labels, values, xlabel, outFilename, log=False):
    """Draw one horizontal bar chart on ax and save the figure.

    fig         -- figure that owns ax; it is written to outFilename
    ax          -- axes to draw on; cleared first so it can be reused
    labels      -- one y-tick label per bar
    values      -- bar lengths, same length as labels
    xlabel      -- label for the x axis
    outFilename -- path of the image file to write
    log         -- if True, draw the bars on a logarithmic x axis
    """
    ax.cla()
    positions = numpy.arange(len(values))
    ax.grid(axis="x")
    # Bars are explicitly centred on their position and the ticks are put
    # at the same position, so labels line up with bars whatever the
    # default bar alignment of the installed matplotlib is.
    ax.set_yticks(positions)
    ax.set_yticklabels(tuple(labels))
    ax.set_xlabel(xlabel)
    ax.barh(positions, values, 0.5, align="center", log=log)
    fig.savefig(outFilename)


fig = mpl.figure()
ax1 = fig.add_subplot(111)
#ax1.set_position([0.25,0.1,0.7,0.85]) #uncomment if you need more space for names

# Input variables versus the regression target.
_saveBarChart(fig, ax1, nameList, mutualInfoList,
              "Mutual Information with " + targetVariableTitle,
              "mutualInfo.png")
_saveBarChart(fig, ax1, nameList, mutualInfoList,
              "Mutual Information with " + targetVariableTitle,
              "mutualInfoLog.png", log=True)
_saveBarChart(fig, ax1, nameList, corrRatioList,
              "Correlation Ratio with " + targetVariableTitle,
              "correlationRatio.png")
_saveBarChart(fig, ax1, nameList, corrList,
              "Correlation with " + targetVariableTitle,
              "correlation.png")

# Overtraining estimates for the MVA outputs.
_saveBarChart(fig, ax1, mutualInfoOvertrainNameList, mutualInfoOvertrainList,
              "Amount of Over Training ($I_{train}/I_{test}-1$)",
              "overtrainMI.png")
_saveBarChart(fig, ax1, mutualInfoOvertrainNameList, rmsOvertrainList,
              "Amount of Over Training ($RMS_{test}/RMS_{train}-1$)",
              "overtrainRMS.png")
# meanOvertrainList is filled earlier in this script with the training
# mean divided by the test mean, so the label states train/test.
_saveBarChart(fig, ax1, mutualInfoOvertrainNameList, meanOvertrainList,
              "Amount of Over Training ($<Offset>_{train}/<Offset>_{test}$)",
              "overtrainMean.png")