Loss Analysis


Function

DeePTB provides a module that helps users understand the error of an E3TB model in detail. The error is decomposed into the following parts:

  • onsite blocks: errors on the diagonal blocks of the predicted quantum tensors (e.g. the Hamiltonian and overlap matrices), further grouped by atom species.

  • hopping blocks: errors on the off-diagonal blocks, further grouped by atom-pair type.
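In formula form, the grouping can be sketched as follows. This assumes the standard MAE/RMSE definitions averaged over all matrix elements in a group; the exact normalization used inside DeePTB is not stated on this page. For an atom species $s$, let $\mathcal{O}_s$ be the set of onsite matrix elements of atoms of that species, and for a pair type $(s,t)$ let $\mathcal{B}_{st}$ be the set of hopping matrix elements between such atoms. With $H$ replaced by $S$ for the overlap matrix:

```latex
\begin{aligned}
\mathrm{MAE}_{\mathrm{onsite}}(s) &= \frac{1}{|\mathcal{O}_s|} \sum_{k \in \mathcal{O}_s} \bigl| H^{\mathrm{pred}}_k - H^{\mathrm{ref}}_k \bigr|, &
\mathrm{RMSE}_{\mathrm{onsite}}(s) &= \Bigl( \frac{1}{|\mathcal{O}_s|} \sum_{k \in \mathcal{O}_s} \bigl( H^{\mathrm{pred}}_k - H^{\mathrm{ref}}_k \bigr)^2 \Bigr)^{1/2}, \\
\mathrm{MAE}_{\mathrm{hopping}}(s,t) &= \frac{1}{|\mathcal{B}_{st}|} \sum_{k \in \mathcal{B}_{st}} \bigl| H^{\mathrm{pred}}_k - H^{\mathrm{ref}}_k \bigr|, &
\mathrm{RMSE}_{\mathrm{hopping}}(s,t) &= \Bigl( \frac{1}{|\mathcal{B}_{st}|} \sum_{k \in \mathcal{B}_{st}} \bigl( H^{\mathrm{pred}}_k - H^{\mathrm{ref}}_k \bigr)^2 \Bigr)^{1/2}.
\end{aligned}
```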

Usage

To use this function, you need a dataset and a trained model. Build both in advance:

from dptb.data import build_dataset
from dptb.nn import build_model

dataset = build_dataset(
    root="your dataset root",
    type="DefaultDataset",
    prefix="frame",
    get_overlap=True,
    get_Hamiltonian=True,
    basis={"Si":"2s2p1d"}
    )

model = build_model("./ovp/checkpoint/nnenv.best.pth", common_options={"device":"cuda"})
model.eval()

Next, iterate over the dataset with a DataLoader and accumulate the analysis as a running average:

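import torch
from dptb.nnops.loss import HamilLossAnalysis
from dptb.data.dataloader import DataLoader
from tqdm import tqdm
from dptb.data import AtomicData

# decompose=True also records per-irreps errors (used for the plots at the end of this page);
# overlap=True also analyzes the overlap matrix.
ana = HamilLossAnalysis(idp=model.idp, device=model.device, decompose=True, overlap=True)

loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)

for data in tqdm(loader, desc="doing error analysis"):
    with torch.no_grad():
        # Reference data, moved to the same device as the model.
        ref_data = AtomicData.to_AtomicDataDict(data.to(model.device))
        # Pass a shallow copy so the model's outputs cannot overwrite the reference entries.
        data = model(ref_data.copy())
        # Accumulate the error statistics over all batches.
        ana(data, ref_data, running_avg=True)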
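Presumably, running_avg=True accumulates the statistics across batches instead of keeping only the last batch. The exact bookkeeping inside HamilLossAnalysis is not documented here, so the following is only a minimal plain-Python sketch of one correct way to do it, using a hypothetical RunningErrorStats class that is not part of DeePTB: keep running sums of absolute and squared errors plus an element count, and take the square root only at the end.

```python
import math


class RunningErrorStats:
    """Accumulate MAE and RMSE over batches (hypothetical helper, not part of DeePTB).

    RMSE is computed from the accumulated sum of squared errors, so the result
    equals the RMSE over all elements at once, not the mean of per-batch RMSEs.
    """

    def __init__(self):
        self.n = 0          # number of matrix elements seen so far
        self.abs_sum = 0.0  # running sum of |pred - ref|
        self.sq_sum = 0.0   # running sum of (pred - ref)**2

    def update(self, pred, ref):
        if len(pred) != len(ref):
            raise ValueError("pred and ref must have the same length")
        for p, r in zip(pred, ref):
            diff = p - r
            self.abs_sum += abs(diff)
            self.sq_sum += diff * diff
            self.n += 1

    @property
    def mae(self):
        return self.abs_sum / self.n

    @property
    def rmse(self):
        return math.sqrt(self.sq_sum / self.n)


# Two toy batches: errors (0, 1) and then (-2,).
stats = RunningErrorStats()
stats.update([1.0, 2.0], [1.0, 1.0])
stats.update([0.0], [2.0])
print(stats.mae, stats.rmse)
```

Here the accumulated RMSE is sqrt(5/3), about 1.29. Averaging the two per-batch RMSEs (about 0.71 and 2.0) would instead give about 1.35, and averaging the per-batch MAEs (0.5 and 2.0) would give 1.25 instead of 1.0, which is why the counts must be carried along.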

The results are stored in ana.stats, a dictionary of statistics. You can inspect it directly or print a summary with:

ana.report()
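When a system has many species or pair types, it can help to flatten ana.stats into rows and sort them by error. The sketch below uses a hypothetical helper, stats_to_rows, and assumes the key layout used by the plotting code on this page: stats["onsite"][species] and stats["hopping"][pair] each hold "mae" and "rmse" scalars. The mock dictionary copies its values from the example report; with a real run, pass ana.stats instead.

```python
def stats_to_rows(stats, parts=("onsite", "hopping")):
    """Flatten a stats dict into (part, type, mae, rmse) rows, largest RMSE first.

    Hypothetical helper. float() accepts Python floats as well as
    single-element torch tensors, so it works on either kind of value.
    """
    rows = []
    for part in parts:
        for name, err in stats.get(part, {}).items():
            rows.append((part, name, float(err["mae"]), float(err["rmse"])))
    # Put the worst-fitted species or pair type on top.
    rows.sort(key=lambda row: row[3], reverse=True)
    return rows


# Mock input mirroring the example report (values copied from it).
example_stats = {
    "onsite": {"Si": {"mae": 0.0012505357153713703, "rmse": 0.0023699181620031595}},
    "hopping": {"Si-Si": {"mae": 0.00016888207755982876, "rmse": 0.0003886453341692686}},
}

rows = stats_to_rows(example_stats)
for part, name, mae, rmse in rows:
    print(f"{part:<8} {name:<6} MAE={mae:.3e} RMSE={rmse:.3e}")
```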

Here is an example of the output:

TOTAL:
MAE: 0.00012021172733511776
RMSE: 0.00034208124270662665


Onsite: 
Si:
MAE: 0.0012505357153713703
RMSE: 0.0023699181620031595
[Figures: MAE_onsite, RMSE_onsite]
Hopping: 
Si-Si:
MAE: 0.00016888207755982876
RMSE: 0.0003886453341692686
[Figures: MAE_hopping, RMSE_hopping]

The errors can also be inspected per irreducible representation (irreps). This requires decompose=True when constructing HamilLossAnalysis, as in the example above; if you ran the analysis with decompose=False, set it to True and rerun it. The per-irreps results can then be printed and plotted with the following code:
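For orientation, an orbital-pair block between orbitals of angular momenta l1 and l2 splits under rotation into irreps with l = |l1 - l2|, ..., l1 + l2, whose dimensions (2l + 1) add up to (2 l1 + 1)(2 l2 + 1), with parity (-1)^(l1 + l2) ("e" for even, "o" for odd, in e3nn notation). We assume this is how model.idp.orbpair_irreps organizes the reduced matrix elements ("rme") shown in the bar plots; the helper below is hypothetical and independent of DeePTB and e3nn.

```python
def pair_irreps(l1, l2):
    """List the irreps contained in the product of two orbitals with angular momenta l1 and l2."""
    parity = "e" if (l1 + l2) % 2 == 0 else "o"
    ls = range(abs(l1 - l2), l1 + l2 + 1)
    # Sanity check: the irreps dimensions add up to the size of the orbital-pair block.
    assert sum(2 * l + 1 for l in ls) == (2 * l1 + 1) * (2 * l2 + 1)
    return [f"{l}{parity}" for l in ls]


L_OF = {"s": 0, "p": 1, "d": 2}
for a, b in [("s", "s"), ("s", "p"), ("p", "p"), ("p", "d"), ("d", "d")]:
    print(f"{a}-{b}: {' + '.join(pair_irreps(L_OF[a], L_OF[b]))}")
```

For a 2s2p1d basis such as the Si example, each s-p pair therefore contributes a single 1o channel, while each d-d pair contributes five channels, 0e through 4e.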

import matplotlib.pyplot as plt
import torch

ana_result = ana.stats

# Scalar errors for each bond (atom-pair) type and each atom species.
for bt, err in ana_result["hopping"].items():
    print("rmse err for bond {bt}: {rmserr} \t mae err for bond {bt}: {maerr}".format(bt=bt, rmserr=err["rmse"], maerr=err["mae"]))

for at, err in ana_result["onsite"].items():
    print("rmse err for atom {at}: {rmserr} \t mae err for atom {at}: {maerr}".format(at=at, rmserr=err["rmse"], maerr=err["mae"]))

# One bar per irreps channel (reduced matrix element, "rme") of the orbital-pair blocks.
x = list(range(model.idp.orbpair_irreps.num_irreps))
# Permutation that reorders the channels into sorted irreps order.
sort_index = torch.LongTensor(model.idp.orbpair_irreps.sort().inv)

for bt, err in ana_result["hopping"].items():
    rmserr = err["rmse_per_irreps"]
    maerr = err["mae_per_irreps"]

    # Uncomment to plot the hopping errors in sorted irreps order as well.
    # rmserr = rmserr[sort_index]
    # maerr = maerr[sort_index]

    plt.figure(figsize=(20, 3))
    plt.bar(x, rmserr.cpu().detach(), label="RMSE per rme")
    plt.bar(x, maerr.cpu().detach(), alpha=0.6, label="MAE per rme")
    plt.legend()
    # plt.yscale("log")
    # plt.ylim([1e-5, 5e-4])
    plt.title("rme specific error of bond type: {bt}".format(bt=bt))
    plt.show()

for at, err in ana_result["onsite"].items():
    # Onsite errors are plotted in sorted irreps order.
    rmserr = err["rmse_per_irreps"][sort_index]
    maerr = err["mae_per_irreps"][sort_index]

    plt.figure(figsize=(20, 3))
    plt.bar(x, rmserr.cpu().detach(), label="RMSE per rme")
    plt.bar(x, maerr.cpu().detach(), alpha=0.6, label="MAE per rme")
    plt.legend()
    # plt.yscale("log")
    # plt.ylim([1e-5, 2.e-2])
    plt.title("rme specific error of atom type: {at}".format(at=at))
    plt.show()