# Source code for atomgen.models.tokengt

"""Implementation of the TokenGT model."""

from typing import Any, Callable, Optional, Tuple

import torch
import torch.nn.functional as f
from torch import nn
from torch.utils.checkpoint import checkpoint
from transformers import PretrainedConfig, PreTrainedModel


ATOM_METADATA = [
    [
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.0,
        0.106761565836299,
        0.4573170731707318,
        0.46896368424867707,
        0.0,
        0.0,
        0.0027383806383189145,
        0.0,
        1.0,
        0.0,
        0.0,
    ],
    [
        0.008547008547008548,
        0.010187317385107808,
        0.011235955056179775,
        0.008547008547008548,
        0.008547008547008548,
        0.0,
        1.0,
        0.0,
        -1.0,
        0.9999999999999999,
        2.1731754967921256e-06,
        -1.0,
        0.0,
        0.010000000000000002,
        0.3588318085855031,
        0.0,
        -1.0,
    ],
    [
        0.017094017094017096,
        0.02018415404448405,
        0.02247191011235955,
        0.017094017094017096,
        0.017094017094017096,
        0.16666666666666666,
        0.0,
        0.5729537366548044,
        0.08536585365853658,
        0.0723802160098582,
        0.01302222611458848,
        0.1117635470484688,
        0.2746530986669577,
        0.010000000000000002,
        0.2454609429978888,
        0.16666666666666666,
        0.0,
    ],
    [
        0.025641025641025644,
        0.027228539455021038,
        0.028089887640449437,
        0.025641025641025644,
        0.025641025641025644,
        0.16666666666666666,
        0.058823529411764705,
        0.32384341637010683,
        0.2652439024390244,
        0.2623432478797689,
        0.0451198574701265,
        0.39298038243761085,
        0.4668171696125004,
        0.015,
        0.12181562280084446,
        0.16666666666666666,
        0.14285714285714285,
    ],
    [
        0.03418803418803419,
        0.03334773276914757,
        0.033707865168539325,
        0.03418803418803419,
        0.03418803418803419,
        0.16666666666666666,
        0.7058823529411764,
        0.25266903914590755,
        0.4085365853658537,
        0.2128252833015198,
        0.057071103187614054,
        0.6504807478441018,
        0.715419845245687,
        0.015,
        0.06558761435608726,
        0.16666666666666666,
        0.2857142857142857,
    ],
    [
        0.042735042735042736,
        0.03742946260625253,
        0.033707865168539325,
        0.042735042735042736,
        0.042735042735042736,
        0.16666666666666666,
        0.7647058823529411,
        0.14946619217081855,
        0.5640243902439024,
        0.3559765143644139,
        0.055363782370830124,
        1.0000000000000002,
        0.7324707832177849,
        0.020000000000000004,
        0.04327938071780436,
        0.16666666666666666,
        0.42857142857142855,
    ],
    [
        0.051282051282051294,
        0.04421873990197045,
        0.03932584269662921,
        0.051282051282051294,
        0.051282051282051294,
        0.16666666666666666,
        0.8235294117647058,
        0.09252669039145908,
        0.7134146341463414,
        0.514180781404789,
        2.8295183993586364e-05,
        0.012484827687008686,
        0.012471056032792366,
        0.025,
        0.06657283603096412,
        0.16666666666666666,
        0.5714285714285714,
    ],
    [
        0.05982905982905984,
        0.050994411431564704,
        0.0449438202247191,
        0.05982905982905984,
        0.05982905982905984,
        0.16666666666666666,
        0.8823529411764706,
        0.05693950177935947,
        0.8353658536585367,
        0.4699156740039142,
        3.268543752245935e-05,
        0.00923366315240946,
        0.014660396468409729,
        0.025,
        0.057987332864180154,
        0.16666666666666666,
        0.7142857142857142,
    ],
    [
        0.06837606837606838,
        0.06119533458279619,
        0.056179775280898875,
        0.06837606837606838,
        0.06837606837606838,
        0.16666666666666666,
        0.9411764705882353,
        0.028469750889679707,
        1.0,
        0.6537753400826345,
        3.9270817815768815e-05,
        0.01002929606822616,
        0.01377886297525227,
        0.015,
        0.051372273047149884,
        0.16666666666666666,
        0.8571428571428572,
    ],
    [
        0.07692307692307693,
        0.06521583847234458,
        0.056179775280898875,
        0.07692307692307693,
        0.07692307692307693,
        0.16666666666666666,
        1.0,
        0.007117437722419961,
        -1.0,
        0.8539203131418077,
        1.9758579909666677e-05,
        0.002676173590325307,
        0.003896139326624358,
        0.025,
        0.06586910626319493,
        0.16666666666666666,
        1.0,
    ],
    [
        0.08547008547008549,
        0.07477388917423204,
        0.06741573033707865,
        0.08547008547008549,
        0.08547008547008549,
        0.33333333333333337,
        0.0,
        0.6085409252669042,
        0.07012195121951223,
        0.06017348442747725,
        0.023680786070796774,
        0.09074155275516495,
        0.19638929337502856,
        0.020000000000000004,
        0.07980295566502463,
        0.33333333333333337,
        0.0,
    ],
    [
        0.09401709401709403,
        0.0792467847873929,
        0.06741573033707865,
        0.09401709401709403,
        0.09401709401709403,
        0.33333333333333337,
        0.058823529411764705,
        0.43060498220640586,
        0.1859756097560976,
        0.18132746997849566,
        0.04243692475803745,
        0.23105764525702377,
        0.2316847349772711,
        0.025,
        0.0653764954257565,
        0.33333333333333337,
        0.14285714285714285,
    ],
    [
        0.10256410256410257,
        0.08835244376566789,
        0.07865168539325842,
        0.10256410256410257,
        0.10256410256410257,
        0.33333333333333337,
        0.7058823529411764,
        0.4661921708185055,
        0.2774390243902439,
        0.10108971416145165,
        0.06585161024536003,
        0.23366315240945865,
        0.4753426385985493,
        0.025,
        0.05650950035186488,
        0.33333333333333337,
        0.2857142857142857,
    ],
    [
        0.11111111111111113,
        0.09210763521580445,
        0.07865168539325842,
        0.11111111111111113,
        0.11111111111111113,
        0.33333333333333337,
        0.7647058823529411,
        0.35943060498220647,
        0.36585365853658536,
        0.20575543044917485,
        0.05682720021378778,
        0.4242464682668294,
        0.6025426358703994,
        0.025,
        0.04299788881069668,
        0.33333333333333337,
        0.42857142857142855,
    ],
    [
        0.11965811965811968,
        0.10193099835710374,
        0.0898876404494382,
        0.11965811965811968,
        0.11965811965811968,
        0.33333333333333337,
        0.8235294117647058,
        0.25266903914590755,
        0.4542682926829268,
        0.3185927948389591,
        0.04438814854864767,
        0.07704039807065373,
        0.09357213740327856,
        0.020000000000000004,
        0.04750175932441942,
        0.33333333333333337,
        0.5714285714285714,
    ],
    [
        0.12820512820512822,
        0.10564197106733833,
        0.0898876404494382,
        0.12820512820512822,
        0.12820512820512822,
        0.33333333333333337,
        0.8823529411764706,
        0.21708185053380794,
        0.5731707317073171,
        0.31247009930654546,
        0.05048572289430458,
        0.09515439218602051,
        0.1216720831812958,
        0.035,
        0.04334975369458127,
        0.33333333333333337,
        0.7142857142857142,
    ],
    [
        0.13675213675213677,
        0.11716605497409803,
        0.10112359550561797,
        0.13675213675213677,
        0.13675213675213677,
        0.33333333333333337,
        0.9411764705882353,
        0.17081850533807832,
        0.75,
        0.43848068233986515,
        7.610016686353661e-05,
        0.04019725595612581,
        0.040050948202660634,
        0.04,
        0.0270935960591133,
        0.33333333333333337,
        0.8571428571428572,
    ],
    [
        0.1452991452991453,
        0.13245553465558702,
        0.12359550561797752,
        0.1452991452991453,
        0.1452991452991453,
        0.33333333333333337,
        1.0,
        0.13879003558718864,
        -1.0,
        0.573402276077029,
        4.1222041606379033e-05,
        0.0177390552812359,
        0.01416591926721889,
        0.025,
        0.029978888106966927,
        0.33333333333333337,
        1.0,
    ],
    [
        0.15384615384615385,
        0.12956430935430432,
        0.11235955056179775,
        0.15384615384615385,
        0.15384615384615385,
        0.5,
        0.0,
        0.8220640569395019,
        0.036585365853658514,
        0.021591320946190845,
        0.02102224365609036,
        0.08193366760083631,
        0.17524613028962724,
        0.035,
        0.04665728360309641,
        0.5,
        0.0,
    ],
    [
        0.1623931623931624,
        0.13289772205460673,
        0.11235955056179775,
        0.1623931623931624,
        0.1623931623931624,
        0.5,
        0.058823529411764705,
        0.6085409252669042,
        0.09146341463414634,
        0.10724623674100561,
        0.03755886528151192,
        0.27910065518972543,
        0.2988654305873366,
        0.05500000000000001,
        0.038916256157635463,
        0.5,
        0.14285714285714285,
    ],
    [
        0.17094017094017097,
        0.14948995384243843,
        0.1348314606741573,
        0.17094017094017097,
        0.17094017094017097,
        0.5,
        0.1176470588235294,
        0.5729537366548044,
        0.201219512195122,
        0.128910044216783,
        0.07292479648632205,
        0.4570377290145464,
        0.5293941119700996,
        0.06,
        0.03335679099225897,
        0.5,
        -1.0,
    ],
    [
        0.17948717948717952,
        0.15939155013894887,
        0.14606741573033707,
        0.17948717948717952,
        0.17948717948717952,
        0.5,
        0.1764705882352941,
        0.5373665480427048,
        0.25609756097560976,
        0.14179331674197213,
        0.11072975742939495,
        0.48779542320426544,
        0.6062938422242609,
        0.03,
        0.03019000703729768,
        0.5,
        -1.0,
    ],
    [
        0.18803418803418806,
        0.16985098284653036,
        0.15730337078651685,
        0.18803418803418806,
        0.18803418803418806,
        0.5,
        0.23529411764705882,
        0.5017793594306051,
        0.2835365853658536,
        0.13783555222654456,
        0.14902252432012042,
        0.5493108115837035,
        0.6267549677907782,
        0.03,
        0.027797325826882473,
        0.5,
        -1.0,
    ],
    [
        0.1965811965811966,
        0.17343610222012087,
        0.15730337078651685,
        0.1965811965811966,
        0.1965811965811966,
        0.5,
        0.2941176470588235,
        0.5017793594306051,
        0.29268292682926833,
        0.13881653659361634,
        0.1743884335980532,
        0.5378719996949651,
        0.5012600643161381,
        0.03,
        0.024982406755805767,
        0.5,
        -1.0,
    ],
    [
        0.20512820512820515,
        0.1834431432040899,
        0.16853932584269662,
        0.20512820512820515,
        0.20512820512820515,
        0.5,
        0.3529411764705882,
        0.4661921708185055,
        0.25914634146341464,
        0.17107304225964678,
        0.1814616198390152,
        0.38255835382787134,
        0.3972493426863412,
        0.04,
        0.0270935960591133,
        0.5,
        -1.0,
    ],
    [
        0.2136752136752137,
        0.18652825067263504,
        0.16853932584269662,
        0.2136752136752137,
        0.2136752136752137,
        0.5,
        0.4117647058823529,
        0.43060498220640586,
        0.3445121951219513,
        0.19370816923188441,
        0.1919494477135451,
        0.4560209457355474,
        0.5336568464631241,
        0.035,
        0.024982406755805767,
        0.5,
        -1.0,
    ],
    [
        0.22222222222222224,
        0.19703190212011848,
        0.1797752808988764,
        0.22222222222222224,
        0.22222222222222224,
        0.5,
        0.47058823529411764,
        0.43060498220640586,
        0.3597560975609756,
        0.19267402807644915,
        0.2160958421223465,
        0.4458531129455577,
        0.5449104655247086,
        0.05500000000000001,
        0.023011963406052074,
        0.5,
        -1.0,
    ],
    [
        0.2307692307692308,
        0.19621555615269748,
        0.1741573033707865,
        0.2307692307692308,
        0.2307692307692308,
        0.5,
        0.5294117647058824,
        0.3950177935943062,
        0.36890243902439024,
        0.18101819411892625,
        0.2173153569914779,
        0.4351768885160684,
        0.5425233342086149,
        0.04,
        0.024630541871921183,
        0.5,
        -1.0,
    ],
    [
        0.23931623931623935,
        0.21272275190225615,
        0.19662921348314605,
        0.23931623931623935,
        0.23931623931623935,
        0.5,
        0.5882352941176471,
        0.3950177935943062,
        0.36585365853658536,
        0.1852030830937251,
        0.2185348718606093,
        0.3415311485202626,
        0.4826745419265514,
        0.04,
        0.02047853624208304,
        0.5,
        -1.0,
    ],
    [
        0.2478632478632479,
        0.2189609956699649,
        0.19662921348314605,
        0.2478632478632479,
        0.2478632478632479,
        0.5,
        0.6470588235294117,
        0.35943060498220647,
        0.28963414634146345,
        0.2657984391233962,
        0.17390062765040062,
        0.17252397384325016,
        0.20048151848833204,
        0.06,
        0.020689655172413793,
        0.5,
        -1.0,
    ],
    [
        0.2564102564102564,
        0.23373345623875397,
        0.2191011235955056,
        0.2564102564102564,
        0.2564102564102564,
        0.5,
        0.7058823529411764,
        0.4661921708185055,
        0.33841463414634154,
        0.10174209292773093,
        0.14414446484359486,
        0.0733952300154424,
        0.4216321839864411,
        0.05500000000000001,
        0.019493314567206193,
        0.5,
        0.2857142857142857,
    ],
    [
        0.26495726495726496,
        0.24365546118444995,
        0.2303370786516854,
        0.26495726495726496,
        0.26495726495726496,
        0.5,
        0.7647058823529411,
        0.35943060498220647,
        0.39939024390243894,
        0.19356319617271125,
        0.12975418938784455,
        0.30434230009087504,
        0.5288825838309366,
        0.07,
        0.015904292751583393,
        0.5,
        0.42857142857142855,
    ],
    [
        0.27350427350427353,
        0.2514175507580112,
        0.23595505617977527,
        0.27350427350427353,
        0.27350427350427353,
        0.5,
        0.8235294117647058,
        0.2882562277580072,
        0.451219512195122,
        0.28485756396936235,
        0.1409737261838533,
        0.2735083471552311,
        0.15052227023008535,
        0.05500000000000001,
        0.016537649542575653,
        0.5,
        0.5714285714285714,
    ],
    [
        0.28205128205128205,
        0.2651525716598694,
        0.25280898876404495,
        0.28205128205128205,
        0.28205128205128205,
        0.5,
        0.8823529411764706,
        0.25266903914590755,
        0.5640243902439024,
        0.2831082223886727,
        0.11731513772270441,
        0.12200763858438347,
        0.1626284361902748,
        0.085,
        0.01597466572836031,
        0.5,
        0.7142857142857142,
    ],
    [
        0.2905982905982906,
        0.2683635324650587,
        0.25280898876404495,
        0.2905982905982906,
        0.2905982905982906,
        0.5,
        0.9411764705882353,
        0.21708185053380794,
        0.6890243902439025,
        0.38272404378186387,
        0.07609553514606364,
        0.06402557209946683,
        0.055889564484942325,
        0.08,
        0.026741731175228708,
        0.5,
        0.8571428571428572,
    ],
    [
        0.29914529914529914,
        0.2816087457864643,
        0.2696629213483146,
        0.29914529914529914,
        0.29914529914529914,
        0.5,
        1.0,
        0.18149466192170824,
        -1.0,
        0.48835141469543564,
        8.878312150250298e-05,
        0.025865695638635226,
        0.019729640327514418,
        0.1,
        0.010837438423645322,
        0.5,
        1.0,
    ],
    [
        0.3076923076923077,
        0.28728915314310205,
        0.2696629213483146,
        0.3076923076923077,
        0.3076923076923077,
        0.6666666666666666,
        0.0,
        0.8932384341637012,
        0.036585365853658514,
        0.013685456785947292,
        0.03731496230768564,
        0.07590668471456988,
        0.16313996432943775,
        0.085,
        0.01893033075299085,
        0.6666666666666666,
        0.0,
    ],
    [
        0.3162393162393162,
        0.2946090553176436,
        0.2808988764044944,
        0.3162393162393162,
        0.3162393162393162,
        0.6666666666666666,
        0.058823529411764705,
        0.7153024911032031,
        0.07621951219512194,
        0.08703215985695992,
        0.06438819240240236,
        0.26130694780724334,
        0.2814734738557968,
        0.075,
        0.014567206192821956,
        0.6666666666666666,
        0.14285714285714285,
    ],
    [
        0.3247863247863248,
        0.29898330912640775,
        0.2808988764044944,
        0.3247863247863248,
        0.3247863247863248,
        0.6666666666666666,
        0.1176470588235294,
        0.6441281138790037,
        0.15853658536585366,
        0.11227680189431463,
        0.109022436612611,
        0.4537331833577997,
        0.6146488018305888,
        0.09,
        0.014356087262491202,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.3333333333333333,
        0.3068678505950822,
        0.28651685393258425,
        0.3333333333333333,
        0.3333333333333333,
        0.6666666666666666,
        0.1764705882352941,
        0.6085409252669042,
        0.19207317073170735,
        0.1324087273781622,
        0.15877864327317145,
        0.5366010205962164,
        0.7976053662711987,
        0.085,
        0.012948627726952853,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.3418803418803419,
        0.312589075250091,
        0.29213483146067415,
        0.3418803418803419,
        0.3418803418803419,
        0.6666666666666666,
        0.23529411764705882,
        0.5729537366548044,
        0.27439024390243905,
        0.13844927151037764,
        0.2090226558813845,
        0.6931856455620589,
        0.8547260084777264,
        0.105,
        0.012033779028852921,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.35042735042735046,
        0.3229770776855231,
        0.3033707865168539,
        0.35042735042735046,
        0.35042735042735046,
        0.6666666666666666,
        0.2941176470588235,
        0.5373665480427048,
        0.44512195121951226,
        0.15456544325512842,
        0.24877884061506758,
        0.7310608227047707,
        0.8368225236070237,
        0.085,
        0.011048557353976075,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.358974358974359,
        0.3299160184086015,
        0.3089887640449438,
        0.358974358974359,
        0.358974358974359,
        0.6666666666666666,
        0.3529411764705882,
        0.5373665480427048,
        0.36585365853658536,
        0.16363109188875738,
        0.28048622721248356,
        0.6250611658691273,
        0.8774037559806166,
        0.1,
        -1.0,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.36752136752136755,
        0.3403584439085284,
        0.3202247191011236,
        0.36752136752136755,
        0.36752136752136755,
        0.6666666666666666,
        0.4117647058823529,
        0.5017793594306051,
        0.4573170731707318,
        0.16752120230990405,
        0.30243749485684845,
        0.6377709568566146,
        0.7534434369234653,
        0.065,
        0.010133708655876143,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.37606837606837606,
        0.346603490559299,
        0.3258426966292135,
        0.37606837606837606,
        0.37606837606837606,
        0.6666666666666666,
        0.47058823529411764,
        0.4661921708185055,
        0.4817073170731707,
        0.17227631865078405,
        0.30243749485684845,
        0.5655793440476872,
        0.67586166915042,
        0.085,
        0.01048557353976073,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.38461538461538464,
        0.35855615609895475,
        0.33707865168539325,
        0.38461538461538464,
        0.38461538461538464,
        0.6666666666666666,
        0.5294117647058824,
        0.4661921708185055,
        0.4573170731707318,
        0.21470510063546525,
        0.2926813759037974,
        0.4603422746712931,
        0.5510488031946638,
        0.09,
        0.01055594651653765,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.39316239316239315,
        0.363481443435728,
        0.34269662921348315,
        0.39316239316239315,
        0.39316239316239315,
        0.6666666666666666,
        0.5882352941176471,
        0.4661921708185055,
        0.375,
        0.17794476526445505,
        0.25609592982985585,
        0.31011254519919423,
        0.41447079003816,
        0.12000000000000001,
        0.00992258972554539,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.4017094017094017,
        0.37893419231070125,
        0.3595505617977528,
        0.4017094017094017,
        0.4017094017094017,
        0.6666666666666666,
        0.6470588235294117,
        0.43060498220640586,
        0.301829268292683,
        0.24644936815908378,
        0.21194949156729978,
        0.1474729758069129,
        0.17661020532739505,
        0.095,
        0.00971147079521464,
        0.6666666666666666,
        -1.0,
    ],
    [
        0.4102564102564103,
        0.38712146207562764,
        0.3707865168539326,
        0.4102564102564103,
        0.4102564102564103,
        0.6666666666666666,
        0.7058823529411764,
        0.5373665480427048,
        0.3292682926829269,
        0.09145383816174166,
        0.1782908811792736,
        0.10567809912365993,
        0.39912494586327196,
        0.15500000000000003,
        0.009781843771991556,
        0.6666666666666666,
        0.2857142857142857,
    ],
    [
        0.4188034188034188,
        0.40035987251397137,
        0.38764044943820225,
        0.4188034188034188,
        0.4188034188034188,
        0.6666666666666666,
        0.7647058823529411,
        0.43060498220640586,
        0.38414634146341464,
        0.16671901804914588,
        0.17780307523162106,
        0.12481904435081566,
        0.4894949171153905,
        0.125,
        0.009429978888106968,
        0.6666666666666666,
        0.42857142857142855,
    ],
    [
        0.4273504273504274,
        0.4107342691832799,
        0.398876404494382,
        0.4273504273504274,
        0.4273504273504274,
        0.6666666666666666,
        0.8235294117647058,
        0.35943060498220647,
        0.41158536585365846,
        0.22782516249063714,
        0.16316889680204447,
        0.22620250509980366,
        0.3164278966985974,
        0.13,
        0.007952146375791697,
        0.6666666666666666,
        0.5714285714285714,
    ],
    [
        0.4358974358974359,
        0.43059868772385734,
        0.42696629213483145,
        0.4358974358974359,
        0.4358974358974359,
        0.6666666666666666,
        0.8823529411764706,
        0.32384341637010683,
        0.426829268292683,
        0.24721289293739582,
        0.15194936000603573,
        0.1801295127701625,
        0.21429277824573129,
        0.13,
        0.007600281491907108,
        0.6666666666666666,
        0.7142857142857142,
    ],
    [
        0.4444444444444445,
        0.42823128441833647,
        0.4157303370786517,
        0.4444444444444445,
        0.4444444444444445,
        0.6666666666666666,
        0.9411764705882353,
        0.2882562277580072,
        0.5975609756097562,
        0.31688211274071565,
        0.12024197340861972,
        0.09468158796128598,
        0.07727144070195302,
        0.105,
        0.008444757213230118,
        0.6666666666666666,
        0.8571428571428572,
    ],
    [
        0.452991452991453,
        0.4431602112975479,
        0.43258426966292135,
        0.452991452991453,
        0.452991452991453,
        0.6666666666666666,
        1.0,
        0.25266903914590755,
        -1.0,
        0.39799453934810447,
        0.0001414661638489788,
        0.03743668935364358,
        0.027419613352930545,
        0.14,
        0.00450387051372273,
        0.6666666666666666,
        1.0,
    ],
    [
        0.46153846153846156,
        0.4486433350453922,
        0.4382022471910112,
        0.46153846153846156,
        0.46153846153846156,
        0.8333333333333334,
        0.0,
        1.0000000000000002,
        0.0274390243902439,
        0.0,
        0.045607663417779054,
        0.0730876530735452,
        0.16024130487418112,
        0.095,
        0.010415200562983815,
        0.8333333333333334,
        0.0,
    ],
    [
        0.47008547008547014,
        0.463684509495124,
        0.4550561797752809,
        0.47008547008547014,
        0.47008547008547014,
        0.8333333333333334,
        0.058823529411764705,
        0.8220640569395019,
        0.05792682926829271,
        0.06368183245946799,
        0.08755897491589865,
        0.25113911501725356,
        0.36928580441210074,
        0.11,
        0.007741027445460941,
        0.8333333333333334,
        0.14285714285714285,
    ],
    [
        0.47863247863247865,
        0.46905198423091704,
        0.4606741573033708,
        0.47863247863247865,
        0.47863247863247865,
        0.8333333333333334,
        0.1176470588235294,
        0.7864768683274024,
        0.12195121951219515,
        0.08132988619614859,
        0.14999813621542551,
        0.2996905165894547,
        0.6364740024348741,
        0.08,
        0.007107670654468685,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.4871794871794872,
        0.47317112992486215,
        0.4606741573033708,
        0.4871794871794872,
        0.4871794871794872,
        0.8333333333333334,
        -1.0,
        0.7864768683274024,
        0.1280487804878049,
        0.07948389590934354,
        0.16512012059265466,
        0.2686786265799859,
        0.6328933054607335,
        0.08,
        0.006896551724137932,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.4957264957264958,
        0.47586507161735137,
        0.4606741573033708,
        0.4957264957264958,
        0.4957264957264958,
        0.8333333333333334,
        -1.0,
        0.7864768683274024,
        0.13109756097560973,
        0.07630898591345106,
        0.16512012059265466,
        0.3024866706067019,
        0.6460225276992488,
        0.06,
        0.006966924700914849,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5042735042735044,
        0.48720547768144135,
        0.47191011235955055,
        0.5042735042735044,
        0.5042735042735044,
        0.8333333333333334,
        -1.0,
        0.7508896797153027,
        0.13414634146341461,
        0.07882185227245272,
        0.1709737919644853,
        0.3240933152854302,
        0.5699753443436925,
        0.065,
        0.006755805770584096,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5128205128205129,
        0.4897837703618793,
        0.47191011235955055,
        0.5128205128205129,
        0.5128205128205129,
        0.8333333333333334,
        -1.0,
        0.7508896797153027,
        0.13109756097560973,
        0.08157634039674291,
        0.17707136631014223,
        0.3024866706067019,
        0.55735765024434,
        0.05500000000000001,
        -1.0,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5213675213675214,
        0.5080154969676149,
        0.4943820224719101,
        0.5213675213675214,
        0.5213675213675214,
        0.8333333333333334,
        -1.0,
        0.7508896797153027,
        0.14329268292682926,
        0.0845579529804045,
        0.18341284362962543,
        0.33832828119141584,
        0.35172333830083996,
        0.07,
        0.007248416608022519,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.52991452991453,
        0.5134714091832119,
        0.5,
        0.52991452991453,
        0.52991452991453,
        0.8333333333333334,
        -1.0,
        0.7508896797153027,
        0.1524390243902439,
        0.08584821320704569,
        0.12780296559723434,
        0.27477932625397977,
        0.30653835267478063,
        0.09,
        0.0061928219563687536,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5384615384615385,
        0.5314514291156592,
        0.5224719101123595,
        0.5384615384615385,
        0.5384615384615385,
        0.8333333333333334,
        -1.0,
        0.7153024911032031,
        0.1524390243902439,
        0.10902940536883562,
        0.19268115663502394,
        0.3993352779313545,
        0.6039067109081672,
        0.07,
        0.009992962702322309,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5470085470085471,
        0.5371488436799516,
        0.5280898876404494,
        0.5470085470085471,
        0.5470085470085471,
        0.8333333333333334,
        -1.0,
        0.7153024911032031,
        0.1524390243902439,
        0.09519414308840943,
        0.20072995477129107,
        0.41077408982009295,
        0.596574807580165,
        0.105,
        0.0061928219563687536,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5555555555555557,
        0.5493089971529934,
        0.5449438202247191,
        0.5555555555555557,
        0.5555555555555557,
        0.8333333333333334,
        -1.0,
        0.7153024911032031,
        0.15853658536585366,
        0.09882330200304443,
        0.20853484993373195,
        0.42348388080758015,
        0.48352708882515627,
        0.09,
        0.005348346235045743,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5641025641025642,
        0.557574500073131,
        0.550561797752809,
        0.5641025641025642,
        0.5641025641025642,
        0.8333333333333334,
        -1.0,
        0.7153024911032031,
        0.16158536585365854,
        0.10281489356561238,
        0.21463242427938886,
        0.43949821745181405,
        0.509615023922466,
        0.13,
        0.004996481351161154,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5726495726495727,
        0.5654964573986455,
        0.5561797752808989,
        0.5726495726495727,
        0.5726495726495727,
        0.8333333333333334,
        -1.0,
        0.7153024911032031,
        0.16463414634146342,
        0.10698045279918819,
        0.22121780457269832,
        0.45271640007880076,
        0.596574807580165,
        0.065,
        0.005207600281491907,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5811965811965812,
        0.5711938719629379,
        0.5617977528089888,
        0.5811965811965812,
        0.5811965811965812,
        0.8333333333333334,
        -1.0,
        0.6797153024911033,
        0.1676829268292683,
        0.11068209824340977,
        0.22731537891835524,
        0.4585629039330449,
        0.37832280153731257,
        0.075,
        0.004644616467276566,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5897435897435899,
        0.5852078110703316,
        0.5786516853932584,
        0.5897435897435899,
        0.5897435897435899,
        0.8333333333333334,
        -1.0,
        0.6797153024911033,
        0.12195121951219515,
        0.11405997052214464,
        0.1699981800691802,
        0.2752877178934793,
        0.24975872922769482,
        0.065,
        0.004292751583391977,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.5982905982905984,
        0.5917147687189832,
        0.5842696629213483,
        0.5982905982905984,
        0.5982905982905984,
        0.8333333333333334,
        -1.0,
        0.6441281138790037,
        0.17378048780487806,
        0.07403290888443234,
        0.2399983335573216,
        0.4885580106635147,
        0.6259024208921734,
        0.095,
        0.00422237860661506,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6068376068376069,
        0.6036980472324172,
        0.5955056179775281,
        0.6068376068376069,
        0.6068376068376069,
        0.8333333333333334,
        0.1764705882352941,
        0.6085409252669042,
        0.1829268292682927,
        0.14164834368279897,
        0.32438876250121335,
        0.6319244530023704,
        0.8306841859370685,
        0.07,
        0.0035186488388458817,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6153846153846155,
        0.6120587905154204,
        0.6067415730337078,
        0.6153846153846155,
        0.6153846153846155,
        0.8333333333333334,
        0.23529411764705882,
        0.5729537366548044,
        0.2439024390243902,
        0.1766593374731196,
        0.40731577360214744,
        0.8274010383899237,
        0.9764697055985051,
        0.08,
        0.003237156931738213,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.623931623931624,
        0.6218957594228435,
        0.6179775280898876,
        0.623931623931624,
        0.623931623931624,
        0.8333333333333334,
        0.2941176470588235,
        0.5373665480427048,
        0.5060975609756098,
        0.19185251407446782,
        0.4707305467969794,
        0.9318755203070687,
        0.99300911543144,
        0.095,
        0.002674173117522871,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6324786324786326,
        0.6299469715265329,
        0.6235955056179775,
        0.6324786324786326,
        0.6324786324786326,
        0.8333333333333334,
        0.3529411764705882,
        0.5373665480427048,
        0.36585365853658536,
        0.19037862130620728,
        0.5121940523474464,
        0.8741730692238767,
        1.0,
        0.09,
        0.00302603800140746,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6410256410256411,
        0.6436309708054273,
        0.6404494382022472,
        0.6410256410256411,
        0.6410256410256411,
        0.8333333333333334,
        0.4117647058823529,
        0.5017793594306051,
        0.4573170731707318,
        0.21960035760021265,
        0.5512185281596508,
        0.8352811088021659,
        0.9004225222429487,
        0.08,
        0.0025334271639690367,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6495726495726497,
        0.650389635127367,
        0.6460674157303371,
        0.6495726495726497,
        0.6495726495726497,
        0.8333333333333334,
        0.47058823529411764,
        0.5017793594306051,
        0.4573170731707318,
        0.24515427549713678,
        0.5512185281596508,
        0.6868307500683152,
        0.8008450444858972,
        0.11,
        0.002603800140745954,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6581196581196582,
        0.6601415679965169,
        0.6573033707865168,
        0.6581196581196582,
        0.6581196581196582,
        0.8333333333333334,
        0.5294117647058824,
        0.4661921708185055,
        0.4817073170731707,
        0.2447531833667577,
        0.5243892010387603,
        0.5162653550162368,
        0.6980278885141472,
        0.14500000000000002,
        0.00274454609429979,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6666666666666667,
        0.6665464823992409,
        0.6629213483146067,
        0.6666666666666667,
        0.6666666666666667,
        0.8333333333333334,
        0.5882352941176471,
        0.4661921708185055,
        0.5609756097560976,
        0.25764612076255833,
        0.4707305467969794,
        0.33644214820887275,
        0.5328042995645191,
        0.09,
        0.002463054187192118,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6752136752136753,
        0.6788699050657669,
        0.6797752808988764,
        0.6752136752136753,
        0.6752136752136753,
        0.8333333333333334,
        0.6470588235294117,
        0.4661921708185055,
        0.39634146341463417,
        0.31621523666851903,
        0.3292668219777389,
        0.05598790027897992,
        0.1067013596417939,
        0.115,
        0.003237156931738213,
        0.8333333333333334,
        -1.0,
    ],
    [
        0.6837606837606839,
        0.6917715727925495,
        0.6910112359550562,
        0.6837606837606839,
        0.6837606837606839,
        0.8333333333333334,
        0.7058823529411764,
        0.5729537366548044,
        0.4085365853658537,
        0.10700461497571703,
        0.2902423461655346,
        0.14310589162361226,
        0.29698982741040586,
        0.125,
        0.002463054187192118,
        0.8333333333333334,
        0.2857142857142857,
    ],
    [
        0.6923076923076924,
        0.7013534335851533,
        0.7022471910112359,
        0.6923076923076924,
        0.6923076923076924,
        0.8333333333333334,
        0.7647058823529411,
        0.4661921708185055,
        0.4969512195121951,
        0.1702370309517481,
        0.27560816773595803,
        0.14910491296970624,
        0.3440504162133959,
        0.13,
        0.002463054187192118,
        0.8333333333333334,
        0.42857142857142855,
    ],
    [
        0.7008547008547009,
        0.7074079995101924,
        0.7078651685393258,
        0.7008547008547009,
        0.7008547008547009,
        0.8333333333333334,
        0.8235294117647058,
        0.3950177935943062,
        0.4024390243902439,
        0.1639017082658806,
        0.2392666246358428,
        0.13484961139814058,
        0.31250618096501487,
        0.08,
        0.0019704433497536944,
        0.8333333333333334,
        0.5714285714285714,
    ],
    [
        0.7094017094017095,
        0.7108774698717316,
        0.7078651685393258,
        0.7094017094017095,
        0.7094017094017095,
        0.8333333333333334,
        0.8823529411764706,
        0.35943060498220647,
        0.39634146341463417,
        0.21857588131538888,
        0.22731537891835524,
        0.13039610063612506,
        0.20985953437298585,
        0.15500000000000003,
        -1.0,
        0.8333333333333334,
        0.7142857142857142,
    ],
    [
        0.7179487179487181,
        0.7108774698717316,
        0.7022471910112359,
        0.7179487179487181,
        0.7179487179487181,
        0.8333333333333334,
        0.9411764705882353,
        0.32384341637010683,
        0.4573170731707318,
        0.26124628506535874,
        0.17072988899065902,
        0.14259749998411278,
        0.10329117204737433,
        0.09,
        -1.0,
        0.8333333333333334,
        0.8571428571428572,
    ],
    [
        0.7264957264957266,
        0.7516947682427814,
        0.7640449438202247,
        0.7264957264957266,
        0.7264957264957266,
        0.8333333333333334,
        1.0,
        0.2882562277580072,
        -1.0,
        0.3312441104694711,
        0.00023512490579826907,
        0.047782459217458176,
        0.03530908235262022,
        0.085,
        0.0,
        0.8333333333333334,
        1.0,
    ],
    [
        0.7350427350427351,
        0.7550962097737021,
        0.7640449438202247,
        0.7350427350427351,
        0.7350427350427351,
        0.9999999999999999,
        0.0,
        -1.0,
        0.0,
        0.00864039432672098,
        0.045607663417779054,
        0.0726936495529331,
        0.161264361152507,
        0.09,
        -1.0,
        0.9999999999999999,
        0.0,
    ],
    [
        0.7435897435897437,
        0.7653005343664645,
        0.7752808988764045,
        0.7435897435897437,
        0.7435897435897437,
        0.9999999999999999,
        0.058823529411764705,
        -1.0,
        0.06097560975609759,
        0.06690506680841815,
        0.1341444429167175,
        0.243767436244511,
        0.34200430365674417,
        0.06,
        -1.0,
        0.9999999999999999,
        0.14285714285714285,
    ],
    [
        0.7521367521367522,
        0.7687019758973853,
        0.7752808988764045,
        0.7521367521367522,
        0.7521367521367522,
        0.9999999999999999,
        0.1176470588235294,
        -1.0,
        0.12195121951219515,
        0.061666706936960886,
        0.24633981087680482,
        0.3327359731569215,
        0.5911185074290938,
        0.04,
        0.0018296973961998584,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.7606837606837608,
        0.7858384383301644,
        0.797752808988764,
        0.7606837606837608,
        0.7606837606837608,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.11659699905767515,
        0.28536428668900904,
        0.5119440260804912,
        0.8622284211854495,
        0.045,
        0.001337086558761435,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.7692307692307694,
        0.7824301939161817,
        0.7865168539325842,
        0.7692307692307694,
        0.7692307692307694,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.2439024390243902,
        0.09646024113852172,
        0.3756083870047315,
        0.4725436740192808,
        0.7324707832177849,
        0.05500000000000001,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.7777777777777779,
        0.8062164745419109,
        0.8202247191011236,
        0.7777777777777779,
        0.7777777777777779,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.2073170731707317,
        0.11115567690337544,
        0.4634134575821911,
        0.3535800303764005,
        0.7502037587087667,
        0.06,
        0.00154820548909219,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.7863247863247864,
        0.8027163912065933,
        0.8089887640449438,
        0.7863247863247864,
        0.7863247863247864,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.201219512195122,
        0.11461570058230844,
        0.49999890365613264,
        0.22851568705952632,
        0.7278670299653185,
        0.75,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.7948717948717949,
        0.826526481923039,
        0.8426966292134831,
        0.7948717948717949,
        0.7948717948717949,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.17682926829268295,
        0.10304201802498372,
        0.4829256954882932,
        0.22851568705952632,
        0.5962337888207231,
        0.8,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8034188034188036,
        0.8231250403921182,
        0.8314606741573034,
        0.8034188034188036,
        0.8034188034188036,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.10050982192475896,
        0.3341448814542644,
        0.3185010072509358,
        0.4903474640139954,
        0.65,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8119658119658121,
        0.8367308065158015,
        0.848314606741573,
        0.8119658119658121,
        0.8119658119658121,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.10136516297388068,
        0.3292668219777389,
        0.3370573020926671,
        0.5761136820136477,
        0.65,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8205128205128206,
        0.8367308065158015,
        0.8426966292134831,
        0.8205128205128206,
        0.8205128205128206,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.11133930944499479,
        0.3609742085751549,
        0.31646744069293786,
        0.16689117068329928,
        0.4,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8290598290598292,
        0.8503365726394846,
        0.8595505617977528,
        0.8290598290598292,
        0.8290598290598292,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.11538889023123203,
        0.3682912977899432,
        0.48576185664626753,
        0.1992879528302852,
        0.6,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8376068376068377,
        0.8537380141704054,
        0.8595505617977528,
        0.8376068376068377,
        0.8376068376068377,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.12207214825911519,
        0.3292668219777389,
        0.28443876740447005,
        -1.0,
        0.6,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8461538461538463,
        0.8707452218250095,
        0.8820224719101123,
        0.8461538461538463,
        0.8461538461538463,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.12593809650373308,
        -1.0,
        -1.0,
        -1.0,
        0.5,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8547008547008548,
        0.8741466633559303,
        0.8820224719101123,
        0.8547008547008548,
        0.8547008547008548,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.12980404474835092,
        -1.0,
        -1.0,
        -1.0,
        0.15000000000000002,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8632478632478634,
        0.8775481048868511,
        0.8820224719101123,
        0.8632478632478634,
        0.8632478632478634,
        0.9999999999999999,
        -1.0,
        -1.0,
        0.1829268292682927,
        0.1331867494623916,
        -1.0,
        -1.0,
        -1.0,
        0.35,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8717948717948719,
        0.8877524294796135,
        0.8932584269662921,
        0.8717948717948719,
        0.8717948717948719,
        0.9999999999999999,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        1.0000000000000002,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8803418803418804,
        0.8843509879486927,
        0.8820224719101123,
        0.8803418803418804,
        0.8803418803418804,
        0.9999999999999999,
        0.1764705882352941,
        -1.0,
        -1.0,
        -1.0,
        0.4414621899378262,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8888888888888891,
        0.8877524294796135,
        0.8820224719101123,
        0.8888888888888891,
        0.8888888888888891,
        0.9999999999999999,
        0.23529411764705882,
        -1.0,
        -1.0,
        -1.0,
        0.9512194052347446,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.8974358974358976,
        0.9013581956032967,
        0.898876404494382,
        0.8974358974358976,
        0.8974358974358976,
        0.9999999999999999,
        0.2941176470588235,
        -1.0,
        -1.0,
        -1.0,
        0.8536582157042338,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9059829059829061,
        0.894555312541455,
        0.8820224719101123,
        0.9059829059829061,
        0.9059829059829061,
        0.9999999999999999,
        0.3529411764705882,
        -1.0,
        -1.0,
        -1.0,
        0.9024388104694893,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9145299145299146,
        0.9047596371342175,
        0.8932584269662921,
        0.9145299145299146,
        0.9145299145299146,
        0.9999999999999999,
        0.4117647058823529,
        -1.0,
        -1.0,
        -1.0,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9230769230769232,
        0.9081610786651383,
        0.8932584269662921,
        0.9230769230769232,
        0.9230769230769232,
        0.9999999999999999,
        0.47058823529411764,
        -1.0,
        -1.0,
        -1.0,
        0.8536582157042338,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9316239316239318,
        0.9183654032579007,
        0.9044943820224719,
        0.9316239316239318,
        0.9316239316239318,
        0.9999999999999999,
        0.5294117647058824,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9401709401709403,
        0.9217668447888215,
        0.9044943820224719,
        0.9401709401709403,
        0.9401709401709403,
        0.9999999999999999,
        0.5882352941176471,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9487179487179489,
        0.965985584690792,
        0.9719101123595505,
        0.9487179487179489,
        0.9487179487179489,
        0.9999999999999999,
        0.6470588235294117,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        -1.0,
    ],
    [
        0.9572649572649574,
        0.9625841431598712,
        0.9606741573033708,
        0.9572649572649574,
        0.9572649572649574,
        0.9999999999999999,
        0.7058823529411764,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        0.2857142857142857,
    ],
    [
        0.9658119658119659,
        0.9795913508144752,
        0.9831460674157303,
        0.9658119658119659,
        0.9658119658119659,
        0.9999999999999999,
        0.7647058823529411,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        0.42857142857142855,
    ],
    [
        0.9743589743589745,
        0.9761899092835544,
        0.9719101123595505,
        0.9743589743589745,
        0.9743589743589745,
        0.9999999999999999,
        0.8235294117647058,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        0.5714285714285714,
    ],
    [
        0.9829059829059831,
        0.9897956754072376,
        0.9887640449438202,
        0.9829059829059831,
        0.9829059829059831,
        0.9999999999999999,
        0.8823529411764706,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        0.7142857142857142,
    ],
    [
        0.9914529914529915,
        1.0,
        1.0,
        0.9914529914529915,
        0.9914529914529915,
        0.9999999999999999,
        0.9411764705882353,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        0.8571428571428572,
    ],
    [
        1.0000000000000002,
        0.9965985584690792,
        0.9887640449438202,
        1.0000000000000002,
        1.0000000000000002,
        0.9999999999999999,
        1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        0.9999999999999999,
        1.0,
    ],
]


class ParallelBlock(nn.Module):
    """Parallel transformer block.

    The attention and MLP branches share a single pre-norm input projection
    (``in_proj``) and a single output projection (``out_proj``). The two halves
    of the output projection are summed together with the residual input.
    Queries and keys are LayerNorm-ed per head before attention.

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        mlp_ratio: Width of the MLP branch as a multiple of ``dim``.
        dropout: Dropout probability applied to the MLP activations and, in
            training mode only, to the attention weights.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: int = 4,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divisible by num_heads {num_heads}"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Not used directly: scaled_dot_product_attention applies the same
        # 1/sqrt(head_dim) factor by default.
        self.scale = self.head_dim**-0.5
        self.mlp_hidden_dim = int(mlp_ratio * dim)
        self.dropout = dropout
        self.proj_drop = nn.Dropout(self.dropout)

        # One fused projection produces [mlp_hidden | q | k | v].
        self.in_proj_in_dim = dim
        self.in_proj_out_dim = self.mlp_hidden_dim + 3 * dim
        # One fused projection consumes [mlp_hidden | attn] and yields [mlp | attn].
        self.out_proj_in_dim = self.mlp_hidden_dim + dim
        self.out_proj_out_dim = 2 * dim

        self.in_split = [self.mlp_hidden_dim] + [dim] * 3
        self.out_split = [dim] * 2

        self.in_norm = nn.LayerNorm(dim)
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)

        # The fused projection must emit mlp_hidden_dim + 3 * dim features so
        # that ``in_split`` can carve out the MLP, q, k and v chunks.
        self.in_proj = nn.Linear(self.in_proj_in_dim, self.in_proj_out_dim, bias=False)
        self.act_fn = nn.GELU()
        self.out_proj = nn.Linear(
            self.out_proj_in_dim, self.out_proj_out_dim, bias=False
        )

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the parallel attention + MLP block.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            attention_mask: Optional mask for
                ``scaled_dot_product_attention``. A 2-D ``(batch, seq_len)``
                key mask is expanded to ``(batch, 1, 1, seq_len)``. Any other
                shape must already broadcast to
                ``(batch, heads, seq_len, seq_len)``. Boolean ``True`` means
                "attend".

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        b, n, c = x.shape
        res = x

        # Single pre-norm, then one fused projection into MLP / q / k / v.
        x = self.in_proj(self.in_norm(x))
        x, q, k, v = torch.split(x, self.in_split, dim=-1)
        x = self.proj_drop(self.act_fn(x))

        q = self.q_norm(q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        k = self.k_norm(k.view(b, n, self.num_heads, self.head_dim).transpose(1, 2))
        v = v.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # A (B, L) key mask would broadcast as (1, 1, B, L) inside SDPA, which
        # is wrong whenever B > 1. Give it explicit head/query axes.
        if attention_mask is not None and attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :]

        x_attn = (
            f.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                # Attention dropout must be disabled at inference time.
                dropout_p=self.dropout if self.training else 0.0,
            )
            .transpose(1, 2)
            .reshape(b, n, c)
        )

        x_mlp, x_attn = self.out_proj(torch.cat([x, x_attn], dim=-1)).split(
            self.out_split, dim=-1
        )
        out: torch.Tensor = x_mlp + x_attn + res
        return out
class TransformerConfig(PretrainedConfig):  # type: ignore
    """Configuration class to store the configuration of a TokenGT model.

    Holds the encoder hyper-parameters and the special-token ids. Any extra
    keyword arguments are forwarded unchanged to ``PretrainedConfig``.

    Args:
        vocab_size: Size of the output vocabulary (used by the atom-modeling head).
        dim: Hidden width of the transformer.
        num_heads: Number of attention heads per block.
        depth: Number of transformer blocks.
        mlp_ratio: MLP width as a multiple of ``dim``.
        k: Controls the positional-encoding width (the encoder expects
            ``2 * k`` node-PE channels).
        sigma: Stored hyper-parameter; not read by the code in this module.
        type_id_dim: Width of the node/edge type embeddings.
        dropout: Dropout probability used inside the blocks.
        mask_token_id, pad_token_id, bos_token_id, eos_token_id, cls_token_id:
            Special-token ids.
        gradient_checkpointing: Whether the encoder checkpoints its blocks
            while training.
    """

    def __init__(
        self,
        vocab_size: int = 123,
        dim: int = 768,
        num_heads: int = 12,
        depth: int = 12,
        mlp_ratio: int = 4,
        k: int = 16,
        sigma: float = 0.03,
        type_id_dim: int = 64,
        dropout: float = 0.0,
        mask_token_id: int = 0,
        pad_token_id: int = 119,
        bos_token_id: int = 120,
        eos_token_id: int = 121,
        cls_token_id: int = 122,
        gradient_checkpointing: bool = False,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)

        # Architecture.
        self.vocab_size = vocab_size
        self.dim = dim
        self.num_heads = num_heads
        self.depth = depth
        self.mlp_ratio = mlp_ratio

        # Positional / type encodings.
        self.k = k
        self.sigma = sigma
        self.type_id_dim = type_id_dim

        # Regularisation.
        self.dropout = dropout

        # Special tokens (assigned after super().__init__ so they take
        # precedence over any values the base class set from kwargs).
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.cls_token_id = cls_token_id

        # Training-time memory saving.
        self.gradient_checkpointing = gradient_checkpointing
class TransformerEncoder(nn.Module):
    """Transformer encoder for atom modeling.

    Atoms (nodes) and edges each become one token. A learned graph token is
    prepended and the sequence is processed by ``depth`` ``ParallelBlock``
    layers.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.dim = config.dim
        self.num_heads = config.num_heads
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.k = config.k
        self.sigma = config.sigma
        self.type_id_dim = config.type_id_dim
        self.gradient_checkpointing = config.gradient_checkpointing
        self.dropout = config.dropout

        # Frozen 17-feature lookup table per token id. Rows [2:-2] are filled
        # from ATOM_METADATA; the remaining rows stay at -1.
        self.metadata_vocab = nn.Embedding(122, 17)
        vocab_weight = torch.empty(122, 17).fill_(-1.0)
        vocab_weight[2:-2] = torch.tensor(ATOM_METADATA, dtype=torch.float32)
        self.metadata_vocab.weight = nn.Parameter(vocab_weight, requires_grad=False)

        # Single learned type vectors distinguishing node tokens from edge tokens.
        self.node_id = nn.Embedding(1, self.type_id_dim)
        self.edge_id = nn.Embedding(1, self.type_id_dim)
        # Node tokens: 17 metadata + 2k PE + type id.
        # Edge tokens: 17 distance features + 2k PE + type id.
        self.embed_proj = nn.Linear(17 + 2 * self.k + self.type_id_dim, self.dim)
        self.graph = nn.Embedding(1, self.dim)
        self.distance = nn.Embedding(1, 17)
        self.distance_norm = nn.LayerNorm(17)

        self.blocks = nn.ModuleList()
        for _ in range(self.depth):
            self.blocks.append(
                ParallelBlock(self.dim, self.num_heads, self.mlp_ratio, self.dropout)
            )

    def _expand_mask(
        self,
        mask: torch.Tensor,
        dtype: torch.dtype,
        device: torch.device,
        tgt_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Turn a ``(bsz, src_len)`` 0/1 mask into an additive attention bias.

        Returns a ``(bsz, 1, tgt_len, src_len)`` tensor that is 0 where
        ``mask`` is 1 and ``torch.finfo(dtype).min`` where it is 0. Not
        called by ``forward`` in this module.
        """
        bsz, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded_mask = (
            mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
        )
        inverted_mask: torch.Tensor = 1.0 - expanded_mask
        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        ).to(device)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        node_pe: torch.Tensor,
        edge_pe: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function call for the transformer encoder.

        Args:
            input_ids: ``(B, N)`` atom token ids, looked up in the frozen
                metadata table.
            coords: Atom coordinates. Not read here; geometry enters only
                through ``node_pe`` and ``edge_pe``.
            node_pe: ``(B, N, 2 * k)`` node positional encodings. The width is
                required by ``embed_proj``.
            edge_pe: ``(B, E, 2 * k + 1)`` edge encodings. The last channel
                scales a learned 17-d edge vector (presumably an interatomic
                distance — TODO confirm against the data collator).
            attention_mask: Optional ``(B, N + E)`` mask; non-zero marks real
                node/edge tokens.

        Returns:
            ``(B, 1 + N + E, dim)`` hidden states: index 0 is the graph token,
            followed by the N node tokens and then the E edge tokens.
        """
        atom_metadata = self.metadata_vocab(input_ids)  # (B, N, 17)
        node_ids = self.node_id(
            torch.zeros(
                node_pe.size(0),
                node_pe.size(1),
                dtype=torch.long,
                device=node_pe.device,
            )
        )
        edge_ids = self.edge_id(
            torch.zeros(
                edge_pe.size(0),
                edge_pe.size(1),
                dtype=torch.long,
                device=edge_pe.device,
            )
        )
        graph_tokens = self.graph(
            torch.zeros(node_pe.size(0), 1, dtype=torch.long, device=node_pe.device)
        )

        nodes = torch.cat([atom_metadata, node_pe, node_ids], dim=-1)

        # NOTE(review): LayerNorm is invariant to positive rescaling of its
        # input, so LayerNorm(v * d) is (up to eps) the same for every d > 0.
        # This edge feature therefore carries almost no distance information.
        # Confirm this is intended.
        distance_embed = self.distance_norm(
            self.distance(
                torch.zeros(
                    edge_pe.size(0),
                    edge_pe.size(1),
                    dtype=torch.long,
                    device=edge_pe.device,
                )
            )
            * edge_pe[:, :, -1:]
        )
        edges = torch.cat([distance_embed, edge_pe[:, :, :-1], edge_ids], dim=-1)

        input_embeds: torch.Tensor = self.embed_proj(torch.cat([nodes, edges], dim=1))
        input_embeds = torch.cat([graph_tokens, input_embeds], dim=1)

        if attention_mask is not None:
            # Convert to Boolean and always keep the prepended graph token.
            key_mask = torch.cat(
                [
                    torch.ones(
                        attention_mask.size(0),
                        1,
                        dtype=torch.bool,
                        device=attention_mask.device,
                    ),
                    attention_mask.bool(),
                ],
                dim=1,
            )
            # scaled_dot_product_attention needs a mask broadcastable to
            # (B, heads, L, L). A bare (B, L) mask would broadcast as
            # (1, 1, B, L), so add explicit head and query axes.
            attention_mask = key_mask[:, None, None, :]

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                input_embeds = checkpoint(
                    blk, input_embeds, attention_mask, use_reentrant=False
                )
            else:
                input_embeds = blk(input_embeds, attention_mask)
        return input_embeds
class TransformerPreTrainedModel(PreTrainedModel):  # type: ignore
    """Base class for all transformer models.

    Binds ``TransformerConfig`` to the Hugging Face ``PreTrainedModel``
    machinery and declares gradient-checkpointing support.
    """

    config_class = TransformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each ParallelBlock on a single device when the model is sharded.
    _no_split_modules = ["ParallelBlock"]

    def _set_gradient_checkpointing(
        self, module: nn.Module, value: bool = False
    ) -> None:
        """Set ``gradient_checkpointing`` on ``module`` if it is an encoder.

        Presumably invoked per-submodule by the Hugging Face
        gradient-checkpointing helpers; other module types are left untouched.
        """
        if not isinstance(module, TransformerEncoder):
            return
        module.gradient_checkpointing = value
class TransformerModel(TransformerPreTrainedModel):
    """Transformer model for atom modeling (bare encoder, no head)."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward function call for the transformer model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates (forwarded to the encoder).
            attention_mask: Optional node+edge token mask.
            node_pe: Node positional encodings; required by the encoder.
            edge_pe: Edge encodings; required by the encoder.

        Returns:
            Encoder hidden states (graph token first).

        Raises:
            ValueError: If ``node_pe`` or ``edge_pe`` is missing.
        """
        if node_pe is None or edge_pe is None:
            raise ValueError("node_pe and edge_pe are required by the TokenGT encoder")
        # Pass by keyword: the encoder's positional order is
        # (input_ids, coords, node_pe, edge_pe, attention_mask).
        out: torch.Tensor = self.encoder(
            input_ids,
            coords,
            node_pe=node_pe,
            edge_pe=edge_pe,
            attention_mask=attention_mask,
        )
        return out
class TransformerForMaskedAM(TransformerPreTrainedModel):
    """Transformer with an atom modeling head on top for masked atom modeling."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.am_head = nn.Linear(config.dim, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the masked atom modeling model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates (forwarded to the encoder).
            labels: Optional target token ids for the cross-entropy loss.
            fixed: Accepted for interface compatibility; not used.
            attention_mask: Optional node+edge token mask.
            node_pe: Node positional encodings; required by the encoder.
            edge_pe: Edge encodings; required by the encoder.

        Returns:
            ``(loss, logits)``. ``loss`` is ``None`` without ``labels``.
            ``logits`` keeps its ``(B, positions, vocab_size)`` shape either way.

        Raises:
            ValueError: If ``node_pe`` or ``edge_pe`` is missing.
        """
        if node_pe is None or edge_pe is None:
            raise ValueError("node_pe and edge_pe are required by the TokenGT encoder")
        hidden_states = self.encoder(
            input_ids,
            coords,
            node_pe=node_pe,
            edge_pe=edge_pe,
            attention_mask=attention_mask,
        )
        # NOTE(review): the encoder puts the graph token at index 0, so node
        # tokens occupy 1..N. This slice ends at N - 1; confirm against the
        # collator's label layout.
        logits = self.am_head(hidden_states[:, 1 : input_ids.size(1)])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Flatten only for the loss so the returned logits shape does not
            # depend on whether labels were supplied.
            loss = loss_fct(
                logits.reshape(-1, self.config.vocab_size),
                labels.reshape(-1).to(logits.device),
            )
        return loss, logits
class TransformerForCoordinateAM(TransformerPreTrainedModel):
    """Transformer with an atom coordinate head on top for coordinate denoising."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the coordinate atom modeling model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates (forwarded to the encoder).
            labels_coords: Optional target coordinates for an L1 loss.
            fixed: Accepted for interface compatibility; not used.
            attention_mask: Optional node+edge token mask.
            node_pe: Node positional encodings; required by the encoder.
            edge_pe: Edge encodings; required by the encoder.

        Returns:
            ``(loss, coords_pred)``; ``loss`` is ``None`` without labels.

        Raises:
            ValueError: If ``node_pe`` or ``edge_pe`` is missing.
        """
        if node_pe is None or edge_pe is None:
            raise ValueError("node_pe and edge_pe are required by the TokenGT encoder")
        hidden_states = self.encoder(
            input_ids,
            coords,
            node_pe=node_pe,
            edge_pe=edge_pe,
            attention_mask=attention_mask,
        )
        # Index 0 is the graph token; the following positions are node tokens.
        coords_pred = self.coords_head(hidden_states[:, 1 : input_ids.size(1)])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)
        return loss, coords_pred
class InitialStructure2RelaxedStructure(TransformerPreTrainedModel):
    """Transformer with an coordinate head on top for relaxed structure prediction."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call.

        Initial structure to relaxed structure model.

        Args:
            input_ids: Atom token ids.
            coords: Initial atom coordinates (forwarded to the encoder).
            labels_coords: Optional relaxed coordinates for an L1 loss.
            fixed: Accepted for interface compatibility; not used.
            attention_mask: Optional node+edge token mask.
            node_pe: Node positional encodings; required by the encoder.
            edge_pe: Edge encodings; required by the encoder.

        Returns:
            ``(loss, coords_pred)``; ``loss`` is ``None`` without labels.

        Raises:
            ValueError: If ``node_pe`` or ``edge_pe`` is missing.
        """
        if node_pe is None or edge_pe is None:
            raise ValueError("node_pe and edge_pe are required by the TokenGT encoder")
        hidden_states = self.encoder(
            input_ids,
            coords,
            node_pe=node_pe,
            edge_pe=edge_pe,
            attention_mask=attention_mask,
        )
        # Index 0 is the graph token; the following positions are node tokens.
        coords_pred = self.coords_head(hidden_states[:, 1 : input_ids.size(1)])

        loss = None
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss = loss_fct(coords_pred, labels_coords)
        return loss, coords_pred
class InitialStructure2RelaxedEnergy(TransformerPreTrainedModel):
    """Transformer with an energy head on top for relaxed energy prediction."""

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Forward function call for the initial structure to relaxed energy model.

        Args:
            input_ids: Atom token ids.
            coords: Atom coordinates (forwarded to the encoder).
            labels_energy: Optional ``(B,)`` target energies for an L1 loss.
            fixed: Accepted for interface compatibility; not used.
            attention_mask: Optional node+edge token mask.
            node_pe: Node positional encodings; required by the encoder.
            edge_pe: Edge encodings; required by the encoder.

        Returns:
            ``(loss, energy)`` where ``energy`` has shape ``(B,)``.

        Raises:
            ValueError: If ``node_pe`` or ``edge_pe`` is missing.
        """
        if node_pe is None or edge_pe is None:
            raise ValueError("node_pe and edge_pe are required by the TokenGT encoder")
        hidden_states = self.encoder(
            input_ids,
            coords,
            node_pe=node_pe,
            edge_pe=edge_pe,
            attention_mask=attention_mask,
        )
        # The energy is read from the graph token at index 0.
        energy = self.energy_head(self.energy_norm(hidden_states[:, 0])).squeeze(-1)

        loss = None
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss = loss_fct(energy, labels_energy.to(energy.device))
        return loss, energy
class InitialStructure2RelaxedStructureAndEnergy(TransformerPreTrainedModel):
    """Initial structure to relaxed structure and energy prediction model.

    Transformer with an coordinate and energy head on top for relaxed structure
    and energy prediction.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)
        self.coords_head = nn.Linear(config.dim, 3)
        self.energy_norm = nn.LayerNorm(config.dim)
        self.energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        labels_coords: Optional[torch.Tensor] = None,
        labels_energy: Optional[torch.Tensor] = None,
        fixed: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward function call.

        Initial structure to relaxed structure and energy model.

        Args:
            input_ids: Atom token ids.
            coords: Initial atom coordinates (forwarded to the encoder).
            labels_coords: Optional relaxed coordinates (L1 loss).
            labels_energy: Optional ``(B,)`` target energies (L1 loss).
            fixed: Accepted for interface compatibility; not used.
            attention_mask: Optional node+edge token mask.
            node_pe: Node positional encodings; required by the encoder.
            edge_pe: Edge encodings; required by the encoder.

        Returns:
            ``(loss, (coords_pred, energy))``. ``loss`` is the sum of both L1
            terms; a missing label contributes 0.

        Raises:
            ValueError: If ``node_pe`` or ``edge_pe`` is missing.
        """
        if node_pe is None or edge_pe is None:
            raise ValueError("node_pe and edge_pe are required by the TokenGT encoder")
        hidden_states = self.encoder(
            input_ids,
            coords,
            node_pe=node_pe,
            edge_pe=edge_pe,
            attention_mask=attention_mask,
        )
        # Node tokens follow the graph token; the energy comes from index 0.
        coords_pred: torch.Tensor = self.coords_head(
            hidden_states[:, 1 : input_ids.size(1)]
        )
        energy: torch.Tensor = self.energy_head(
            self.energy_norm(hidden_states[:, 0])
        ).squeeze(-1)

        loss_coords = torch.tensor(0.0, device=input_ids.device)
        if labels_coords is not None:
            labels_coords = labels_coords.to(coords_pred.device)
            loss_fct = nn.L1Loss()
            loss_coords = loss_fct(coords_pred, labels_coords)

        loss_energy = torch.tensor(0.0, device=input_ids.device)
        if labels_energy is not None:
            loss_fct = nn.L1Loss()
            loss_energy = loss_fct(energy, labels_energy.to(energy.device))

        loss = loss_coords + loss_energy
        return loss, (coords_pred, energy)
class Structure2EnergyAndForces(TransformerPreTrainedModel):
    """Structure to energy and forces prediction model.

    Transformer with an energy and forces head on top for energy and forces
    prediction.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = TransformerEncoder(config)
        # Per-position force head (applied to encoder positions 1..N-1).
        self.force_norm = nn.LayerNorm(config.dim)
        self.force_head = nn.Linear(config.dim, 3)
        # Scalar formation-energy head applied to encoder position 0.
        self.energy_norm = nn.LayerNorm(config.dim)
        self.formation_energy_head = nn.Linear(config.dim, 1, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        coords: torch.Tensor,
        forces: Optional[torch.Tensor] = None,
        total_energy: Optional[torch.Tensor] = None,
        formation_energy: Optional[torch.Tensor] = None,
        has_formation_energy: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        node_pe: Optional[torch.Tensor] = None,
        edge_pe: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
    ]:
        """Forward function call for the structure to energy and forces model.

        Args:
            input_ids: Token ids; ``input_ids.size(1)`` bounds the slice of
                encoder outputs fed to the force head.
            coords: Input coordinates, passed to the encoder.
            forces: Optional target forces. Compared (L1) with the predictions
                at positions where ``attention_mask[:, 1:]`` is nonzero; ignored
                when ``attention_mask`` is ``None``.
            total_energy: Accepted for interface compatibility; not used here.
            formation_energy: Optional target formation energies (L1 loss).
            has_formation_energy: Optional index/mask selecting which entries of
                ``formation_energy`` participate in the loss (presumably a
                boolean mask -- confirm with the data collator). When ``None``,
                every entry is used.
            attention_mask: Passed to the encoder and used to mask the force loss.
            node_pe: Passed to the encoder.
            edge_pe: Passed to the encoder.

        Returns:
            ``(loss, (formation_energy_pred, forces_pred, mask))`` where
            ``loss`` is the sum of the formation-energy and force L1 losses
            (each a zero tensor when not computed or when its selection is
            empty), and ``mask`` is ``attention_mask.bool()`` or ``None``.
        """
        hidden_states = self.encoder(
            input_ids, coords, attention_mask, node_pe, edge_pe
        )
        num_tokens = input_ids.size(1)
        loss_fct = nn.L1Loss()

        formation_energy_pred: torch.Tensor = self.formation_energy_head(
            self.energy_norm(hidden_states[:, 0])
        ).squeeze(-1)
        # torch.tensor (lowercase) builds a 0-d tensor from a Python scalar;
        # the legacy torch.Tensor(0.0, device=...) constructor raises TypeError.
        loss_formation_energy = torch.tensor(0.0, device=input_ids.device)
        if formation_energy is not None:
            energy_pred = formation_energy_pred
            energy_target = formation_energy.to(formation_energy_pred.device)
            if has_formation_energy is not None:
                selector = has_formation_energy.to(energy_pred.device)
                energy_pred = energy_pred[selector]
                energy_target = energy_target[selector]
            # An empty selection would make L1Loss (mean reduction) return NaN
            # and poison the summed loss; keep the zero placeholder instead.
            if energy_pred.numel() > 0:
                loss_formation_energy = loss_fct(energy_pred, energy_target)

        forces_pred: torch.Tensor = self.force_head(
            self.force_norm(hidden_states[:, 1:num_tokens])
        )
        loss_forces = torch.tensor(0.0, device=input_ids.device)
        if forces is not None and attention_mask is not None:
            # Drop position 0 of the mask so it aligns with forces_pred.
            atom_mask = attention_mask[:, 1:num_tokens].bool()
            selected_pred = forces_pred[atom_mask]
            selected_target = forces.to(forces_pred.device)[atom_mask]
            # Same NaN guard as above for a fully-masked batch.
            if selected_pred.numel() > 0:
                loss_forces = loss_fct(selected_pred, selected_target)

        loss = loss_formation_energy + loss_forces
        return loss, (
            formation_energy_pred,
            forces_pred,
            attention_mask.bool() if attention_mask is not None else attention_mask,
        )