program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3520.2.1"}, {"coremlc-version", "3520.2.1"}, {"coremltools-component-torch", "2.6.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.2.2"}, {"mldb_token", "mldb-wrxg5d7ao8"}})] { func main(tensor alreadyPrompted, tensor candidateInteractions, tensor candidate_risk, tensor component, tensor deviceContext, tensor forcedPromptRate, tensor isResolved, tensor parameterName, tensor riskLevel, tensor similarityScores, tensor tupleInteractions_alignment, tensor tupleInteractions_candidates, tensor tuples) { tensor var_18_perm_0 = const()[name = tensor("op_18_perm_0"), val = tensor([1, 0])]; tensor alignments_1_begin_0 = const()[name = tensor("alignments_1_begin_0"), val = tensor([1, 0])]; tensor alignments_1_end_0 = const()[name = tensor("alignments_1_end_0"), val = tensor([2, 1000])]; tensor alignments_1_end_mask_0 = const()[name = tensor("alignments_1_end_mask_0"), val = tensor([false, true])]; tensor alignments_1_squeeze_mask_0 = const()[name = tensor("alignments_1_squeeze_mask_0"), val = tensor([true, false])]; tensor var_18 = transpose(perm = var_18_perm_0, x = candidateInteractions)[name = tensor("transpose_48")]; tensor alignments_1 = slice_by_index(begin = alignments_1_begin_0, end = alignments_1_end_0, end_mask = alignments_1_end_mask_0, squeeze_mask = alignments_1_squeeze_mask_0, x = var_18)[name = tensor("alignments_1")]; tensor var_23_axes_0 = const()[name = tensor("op_23_axes_0"), val = tensor([1])]; tensor var_23 = expand_dims(axes = var_23_axes_0, x = alignments_1)[name = tensor("op_23")]; tensor var_24 = const()[name = tensor("op_24"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))]; tensor var_26 = sub(x = var_23, y = var_24)[name = tensor("op_26")]; tensor var_27 = abs(x = var_26)[name = tensor("op_27")]; tensor var_28 = const()[name = tensor("op_28"), val = tensor(0x1.0624dep-10)]; tensor var_29 = less(x = var_27, y = var_28)[name = tensor("op_29")]; tensor var_29_promoted_dtype_0 = const()[name = tensor("op_29_promoted_dtype_0"), val = tensor("fp32")]; tensor transpose_0 = const()[name = tensor("transpose_0"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(256)))]; tensor new_alignments_bias_0 = const()[name = tensor("new_alignments_bias_0"), val = tensor([0x0p+0])]; tensor var_29_promoted = cast(dtype = var_29_promoted_dtype_0, x = var_29)[name = tensor("cast_261")]; tensor new_alignments = linear(bias = new_alignments_bias_0, weight = transpose_0, x = var_29_promoted)[name = tensor("new_alignments")]; tensor var_38 = squeeze(x = new_alignments)[name = tensor("op_38")]; tensor shape_15 = const()[name = tensor("shape_15"), val = tensor([8, 1000])]; tensor slice_by_index_24 = const()[name = tensor("slice_by_index_24"), val = tensor([1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119, 1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143, 1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151, 1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163, 1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175, 1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211, 1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223, 1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235, 1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279, 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287, 1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295, 1296, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311, 1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319, 1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327, 1328, 1329, 1330, 1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359, 1360, 1361, 1362, 1363, 1364, 1365, 1366, 1367, 1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375, 1376, 1377, 1378, 1379, 1380, 1381, 1382, 1383, 1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391, 1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1400, 1401, 1402, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1417, 1418, 1419, 1420, 1421, 1422, 1423, 1424, 1425, 1426, 1427, 1428, 1429, 1430, 1431, 1432, 1433, 1434, 1435, 1436, 1437, 1438, 1439, 1440, 1441, 1442, 1443, 1444, 1445, 1446, 1447, 1448, 1449, 1450, 1451, 1452, 1453, 1454, 1455, 1456, 1457, 1458, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1469, 1470, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1480, 1481, 1482, 1483, 1484, 1485, 1486, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1494, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1502, 1503, 1504, 1505, 1506, 1507, 1508, 1509, 1510, 1511, 1512, 1513, 1514, 1515, 1516, 1517, 1518, 1519, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1527, 1528, 1529, 1530, 1531, 1532, 1533, 1534, 1535, 1536, 1537, 1538, 1539, 1540, 1541, 1542, 1543, 1544, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1553, 1554, 1555, 1556, 1557, 1558, 1559, 1560, 1561, 1562, 1563, 1564, 1565, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1573, 1574, 1575, 1576, 1577, 1578, 1579, 1580, 1581, 1582, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1591, 1592, 1593, 1594, 1595, 1596, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1604, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1616, 1617, 1618, 1619, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1628, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1638, 1639, 1640, 1641, 1642, 1643, 1644, 1645, 1646, 1647, 1648, 1649, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1659, 1660, 1661, 1662, 1663, 1664, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1672, 1673, 1674, 1675, 1676, 1677, 1678, 1679, 1680, 1681, 1682, 1683, 1684, 1685, 1686, 1687, 1688, 1689, 1690, 1691, 1692, 1693, 1694, 1695, 1696, 1697, 1698, 1699, 1700, 1701, 1702, 1703, 1704, 1705, 1706, 1707, 1708, 1709, 1710, 1711, 1712, 1713, 1714, 1715, 1716, 1717, 1718, 1719, 1720, 1721, 1722, 1723, 1724, 1725, 1726, 1727, 1728, 1729, 1730, 1731, 1732, 1733, 1734, 1735, 1736, 1737, 1738, 1739, 1740, 1741, 1742, 1743, 1744, 1745, 1746, 1747, 1748, 1749, 1750, 1751, 1752, 1753, 1754, 1755, 1756, 1757, 1758, 1759, 1760, 1761, 1762, 1763, 1764, 1765, 1766, 1767, 1768, 1769, 1770, 1771, 1772, 1773, 1774, 1775, 1776, 1777, 1778, 1779, 1780, 1781, 1782, 1783, 1784, 1785, 1786, 1787, 1788, 1789, 1790, 1791, 1792, 1793, 1794, 1795, 1796, 1797, 1798, 1799, 1800, 1801, 1802, 1803, 1804, 1805, 1806, 1807, 1808, 1809, 1810, 1811, 1812, 1813, 1814, 1815, 1816, 1817, 1818, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827, 1828, 1829, 1830, 1831, 1832, 1833, 1834, 1835, 1836, 1837, 1838, 1839, 1840, 1841, 1842, 1843, 1844, 1845, 1846, 1847, 1848, 1849, 1850, 1851, 1852, 1853, 1854, 1855, 1856, 1857, 1858, 1859, 1860, 1861, 1862, 1863, 1864, 1865, 1866, 1867, 1868, 1869, 1870, 1871, 1872, 1873, 1874, 1875, 1876, 1877, 1878, 1879, 1880, 1881, 1882, 1883, 1884, 1885, 1886, 1887, 1888, 1889, 1890, 1891, 1892, 1893, 1894, 1895, 1896, 1897, 1898, 1899, 1900, 1901, 1902, 1903, 1904, 1905, 1906, 1907, 1908, 1909, 1910, 1911, 1912, 1913, 1914, 1915, 1916, 1917, 1918, 1919, 1920, 1921, 1922, 1923, 1924, 1925, 1926, 1927, 1928, 1929, 1930, 1931, 1932, 1933, 1934, 1935, 1936, 1937, 1938, 1939, 1940, 1941, 1942, 1943, 1944, 1945, 1946, 1947, 1948, 1949, 1950, 1951, 1952, 1953, 1954, 1955, 1956, 1957, 1958, 1959, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967, 1968, 1969, 1970, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999])]; tensor reshape_3_shape_0 = const()[name = tensor("reshape_3_shape_0"), val = tensor([-1])]; tensor reshape_3 = reshape(shape = reshape_3_shape_0, x = var_18)[name = tensor("reshape_3")]; tensor scatter_0_mode_0 = const()[name = tensor("scatter_0_mode_0"), val = tensor("update")]; tensor scatter_0_axis_0 = const()[name = tensor("scatter_0_axis_0"), val = tensor(0)]; tensor scatter_0 = scatter(axis = scatter_0_axis_0, data = reshape_3, indices = slice_by_index_24, mode = scatter_0_mode_0, updates = var_38)[name = tensor("scatter_0")]; tensor reshape_4 = reshape(shape = shape_15, x = scatter_0)[name = tensor("reshape_4")]; tensor var_51 = const()[name = tensor("op_51"), val = tensor(0)]; tensor flat_1_interleave_0 = const()[name = tensor("flat_1_interleave_0"), val = tensor(false)]; tensor flat_1 = concat(axis = var_51, interleave = flat_1_interleave_0, values = forcedPromptRate)[name = tensor("flat_1")]; tensor var_53 = const()[name = tensor("op_53"), val = tensor(0x1.eb851ep-6)]; tensor var_54_div = floor_div(x = flat_1, y = var_53)[name = tensor("op_54_div")]; tensor var_54_div_scaled = mul(x = var_54_div, y = var_53)[name = tensor("op_54_div_scaled")]; tensor var_54 = sub(x = flat_1, y = var_54_div_scaled)[name = tensor("op_54")]; tensor var_55_promoted = const()[name = tensor("op_55_promoted"), val = tensor(0x1.9p+6)]; tensor out_1 = mul(x = var_54, y = var_55_promoted)[name = tensor("out_1")]; tensor var_57 = const()[name = tensor("op_57"), val = tensor(0x1.1eb852p-4)]; tensor var_58_div = floor_div(x = out_1, y = var_57)[name = tensor("op_58_div")]; tensor var_58_div_scaled = mul(x = var_58_div, y = var_57)[name = tensor("op_58_div_scaled")]; tensor var_58 = sub(x = out_1, y = var_58_div_scaled)[name = tensor("op_58")]; tensor var_59_promoted = const()[name = tensor("op_59_promoted"), val = tensor(0x1.9p+6)]; tensor out_3 = mul(x = var_58, y = var_59_promoted)[name = tensor("out_3")]; tensor out_5_keep_dims_0 = const()[name = tensor("out_5_keep_dims_0"), val = tensor(false)]; tensor out_5 = reduce_sum(keep_dims = out_5_keep_dims_0, x = out_3)[name = tensor("out_5")]; tensor var_63 = const()[name = tensor("op_63"), val = tensor(0x1.851eb8p-3)]; tensor var_64_div = floor_div(x = out_5, y = var_63)[name = tensor("op_64_div")]; tensor var_64_div_scaled = mul(x = var_64_div, y = var_63)[name = tensor("op_64_div_scaled")]; tensor var_64 = sub(x = out_5, y = var_64_div_scaled)[name = tensor("op_64")]; tensor var_65_promoted = const()[name = tensor("op_65_promoted"), val = tensor(0x1.9p+6)]; tensor out_7 = mul(x = var_64, y = var_65_promoted)[name = tensor("out_7")]; tensor var_67 = const()[name = tensor("op_67"), val = tensor(0x1.b645a2p-4)]; tensor var_68_div = floor_div(x = out_7, y = var_67)[name = tensor("op_68_div")]; tensor var_68_div_scaled = mul(x = var_68_div, y = var_67)[name = tensor("op_68_div_scaled")]; tensor var_68 = sub(x = out_7, y = var_68_div_scaled)[name = tensor("op_68")]; tensor var_69_promoted = const()[name = tensor("op_69_promoted"), val = tensor(0x1.f4p+9)]; tensor var_70 = mul(x = var_68, y = var_69_promoted)[name = tensor("op_70")]; tensor _inversed_72_y_0 = const()[name = tensor("_inversed_72_y_0"), val = tensor(0x1.323e34p-7)]; tensor _inversed_72 = mul(x = var_70, y = _inversed_72_y_0)[name = tensor("_inversed_72")]; tensor var_73_promoted = const()[name = tensor("op_73_promoted"), val = tensor(0x1.9p+6)]; tensor var_74 = mul(x = _inversed_72, y = var_73_promoted)[name = tensor("op_74")]; tensor var_77_begin_0 = const()[name = tensor("op_77_begin_0"), val = tensor([0, 0])]; tensor var_77_end_0 = const()[name = tensor("op_77_end_0"), val = tensor([1, 2])]; tensor var_77_end_mask_0 = const()[name = tensor("op_77_end_mask_0"), val = tensor([false, true])]; tensor var_77_squeeze_mask_0 = const()[name = tensor("op_77_squeeze_mask_0"), val = tensor([true, false])]; tensor var_77 = slice_by_index(begin = var_77_begin_0, end = var_77_end_0, end_mask = var_77_end_mask_0, squeeze_mask = var_77_squeeze_mask_0, x = riskLevel)[name = tensor("op_77")]; tensor var_80_begin_0 = const()[name = tensor("op_80_begin_0"), val = tensor([0])]; tensor var_80_end_0 = const()[name = tensor("op_80_end_0"), val = tensor([1])]; tensor var_80_end_mask_0 = const()[name = tensor("op_80_end_mask_0"), val = tensor([false])]; tensor var_80_squeeze_mask_0 = const()[name = tensor("op_80_squeeze_mask_0"), val = tensor([true])]; tensor var_80 = slice_by_index(begin = var_80_begin_0, end = var_80_end_0, end_mask = var_80_end_mask_0, squeeze_mask = var_80_squeeze_mask_0, x = var_77)[name = tensor("op_80")]; tensor var_81_promoted = const()[name = tensor("op_81_promoted"), val = tensor(0x1.4p+2)]; tensor var_82 = equal(x = var_80, y = var_81_promoted)[name = tensor("op_82")]; tensor var_83 = const()[name = tensor("op_83"), val = tensor(1000)]; tensor var_82_promoted_dtype_0 = const()[name = tensor("op_82_promoted_dtype_0"), val = tensor("int32")]; tensor var_82_promoted = cast(dtype = var_82_promoted_dtype_0, x = var_82)[name = tensor("cast_260")]; tensor var_84 = mul(x = var_82_promoted, y = var_83)[name = tensor("op_84")]; tensor var_84_promoted_dtype_0 = const()[name = tensor("op_84_promoted_dtype_0"), val = tensor("fp32")]; tensor var_84_promoted = cast(dtype = var_84_promoted_dtype_0, x = var_84)[name = tensor("cast_259")]; tensor var_86 = add(x = var_84_promoted, y = var_74)[name = tensor("op_86")]; tensor var_88_promoted = const()[name = tensor("op_88_promoted"), val = tensor(0x1p+0)]; tensor i_i = add(x = forcedPromptRate, y = var_88_promoted)[name = tensor("i_i")]; tensor var_94 = const()[name = tensor("op_94"), val = tensor(0)]; tensor flat_interleave_0 = const()[name = tensor("flat_interleave_0"), val = tensor(false)]; tensor flat = concat(axis = var_94, interleave = flat_interleave_0, values = i_i)[name = tensor("flat")]; tensor var_96 = const()[name = tensor("op_96"), val = tensor(0x1.eb851ep-6)]; tensor var_97_div = floor_div(x = flat, y = var_96)[name = tensor("op_97_div")]; tensor var_97_div_scaled = mul(x = var_97_div, y = var_96)[name = tensor("op_97_div_scaled")]; tensor var_97 = sub(x = flat, y = var_97_div_scaled)[name = tensor("op_97")]; tensor var_98_promoted = const()[name = tensor("op_98_promoted"), val = tensor(0x1.9p+6)]; tensor out_9 = mul(x = var_97, y = var_98_promoted)[name = tensor("out_9")]; tensor var_100 = const()[name = tensor("op_100"), val = tensor(0x1.1eb852p-4)]; tensor var_101_div = floor_div(x = out_9, y = var_100)[name = tensor("op_101_div")]; tensor var_101_div_scaled = mul(x = var_101_div, y = var_100)[name = tensor("op_101_div_scaled")]; tensor var_101 = sub(x = out_9, y = var_101_div_scaled)[name = tensor("op_101")]; tensor var_102_promoted = const()[name = tensor("op_102_promoted"), val = tensor(0x1.9p+6)]; tensor out_11 = mul(x = var_101, y = var_102_promoted)[name = tensor("out_11")]; tensor out_13_keep_dims_0 = const()[name = tensor("out_13_keep_dims_0"), val = tensor(false)]; tensor out_13 = reduce_sum(keep_dims = out_13_keep_dims_0, x = out_11)[name = tensor("out_13")]; tensor var_106 = const()[name = tensor("op_106"), val = tensor(0x1.851eb8p-3)]; tensor var_107_div = floor_div(x = out_13, y = var_106)[name = tensor("op_107_div")]; tensor var_107_div_scaled = mul(x = var_107_div, y = var_106)[name = tensor("op_107_div_scaled")]; tensor var_107 = sub(x = out_13, y = var_107_div_scaled)[name = tensor("op_107")]; tensor var_108_promoted = const()[name = tensor("op_108_promoted"), val = tensor(0x1.9p+6)]; tensor out_15 = mul(x = var_107, y = var_108_promoted)[name = tensor("out_15")]; tensor var_110 = const()[name = tensor("op_110"), val = tensor(0x1.b645a2p-4)]; tensor var_111_div = floor_div(x = out_15, y = var_110)[name = tensor("op_111_div")]; tensor var_111_div_scaled = mul(x = var_111_div, y = var_110)[name = tensor("op_111_div_scaled")]; tensor var_111 = sub(x = out_15, y = var_111_div_scaled)[name = tensor("op_111")]; tensor var_112_promoted = const()[name = tensor("op_112_promoted"), val = tensor(0x1.f4p+9)]; tensor var_113 = mul(x = var_111, y = var_112_promoted)[name = tensor("op_113")]; tensor _inversed_115_y_0 = const()[name = tensor("_inversed_115_y_0"), val = tensor(0x1.323e34p-7)]; tensor _inversed_115 = mul(x = var_113, y = _inversed_115_y_0)[name = tensor("_inversed_115")]; tensor var_116_promoted = const()[name = tensor("op_116_promoted"), val = tensor(0x1.9p+6)]; tensor random_seed = mul(x = _inversed_115, y = var_116_promoted)[name = tensor("random_seed")]; tensor var_120_begin_0 = const()[name = tensor("op_120_begin_0"), val = tensor([1])]; tensor var_120_end_0 = const()[name = tensor("op_120_end_0"), val = tensor([2])]; tensor var_120_end_mask_0 = const()[name = tensor("op_120_end_mask_0"), val = tensor([false])]; tensor var_120_squeeze_mask_0 = const()[name = tensor("op_120_squeeze_mask_0"), val = tensor([true])]; tensor var_120 = slice_by_index(begin = var_120_begin_0, end = var_120_end_0, end_mask = var_120_end_mask_0, squeeze_mask = var_120_squeeze_mask_0, x = forcedPromptRate)[name = tensor("op_120")]; tensor var_121 = less(x = var_86, y = var_120)[name = tensor("op_121")]; tensor var_121_promoted_dtype_0 = const()[name = tensor("op_121_promoted_dtype_0"), val = tensor("fp32")]; tensor var_130_begin_0 = const()[name = tensor("op_130_begin_0"), val = tensor([2])]; tensor var_130_end_0 = const()[name = tensor("op_130_end_0"), val = tensor([3])]; tensor var_130_end_mask_0 = const()[name = tensor("op_130_end_mask_0"), val = tensor([false])]; tensor var_130_squeeze_mask_0 = const()[name = tensor("op_130_squeeze_mask_0"), val = tensor([true])]; tensor var_130 = slice_by_index(begin = var_130_begin_0, end = var_130_end_0, end_mask = var_130_end_mask_0, squeeze_mask = var_130_squeeze_mask_0, x = forcedPromptRate)[name = tensor("op_130")]; tensor var_132 = add(x = var_120, y = var_130)[name = tensor("op_132")]; tensor var_133 = less(x = var_86, y = var_132)[name = tensor("op_133")]; tensor var_134_promoted = const()[name = tensor("op_134_promoted"), val = tensor(0x1p+0)]; tensor var_121_promoted = cast(dtype = var_121_promoted_dtype_0, x = var_121)[name = tensor("cast_258")]; tensor var_136 = sub(x = var_134_promoted, y = var_121_promoted)[name = tensor("op_136")]; tensor var_133_promoted_dtype_0 = const()[name = tensor("op_133_promoted_dtype_0"), val = tensor("fp32")]; tensor var_133_promoted = cast(dtype = var_133_promoted_dtype_0, x = var_133)[name = tensor("cast_257")]; tensor forced_1 = mul(x = var_133_promoted, y = var_136)[name = tensor("forced_1")]; tensor var_148_begin_0 = const()[name = tensor("op_148_begin_0"), val = tensor([3])]; tensor var_148_end_0 = const()[name = tensor("op_148_end_0"), val = tensor([4])]; tensor var_148_end_mask_0 = const()[name = tensor("op_148_end_mask_0"), val = tensor([false])]; tensor var_148_squeeze_mask_0 = const()[name = tensor("op_148_squeeze_mask_0"), val = tensor([true])]; tensor var_148 = slice_by_index(begin = var_148_begin_0, end = var_148_end_0, end_mask = var_148_end_mask_0, squeeze_mask = var_148_squeeze_mask_0, x = forcedPromptRate)[name = tensor("op_148")]; tensor var_150 = add(x = var_132, y = var_148)[name = tensor("op_150")]; tensor var_151 = less(x = var_86, y = var_150)[name = tensor("op_151")]; tensor var_151_promoted_dtype_0 = const()[name = tensor("op_151_promoted_dtype_0"), val = tensor("fp32")]; tensor var_151_promoted = cast(dtype = var_151_promoted_dtype_0, x = var_151)[name = tensor("cast_256")]; tensor var_155 = mul(x = var_151_promoted, y = var_136)[name = tensor("op_155")]; tensor var_156_promoted = const()[name = tensor("op_156_promoted"), val = tensor(0x1p+0)]; tensor var_158 = sub(x = var_156_promoted, y = forced_1)[name = tensor("op_158")]; tensor forced_parameter_confirm = mul(x = var_155, y = var_158)[name = tensor("forced_parameter_confirm")]; tensor var_162_begin_0 = const()[name = tensor("op_162_begin_0"), val = tensor([0, 0, 0])]; tensor var_162_end_0 = const()[name = tensor("op_162_end_0"), val = tensor([1, 50, 15])]; tensor var_162_end_mask_0 = const()[name = tensor("op_162_end_mask_0"), val = tensor([false, true, true])]; tensor var_162_squeeze_mask_0 = const()[name = tensor("op_162_squeeze_mask_0"), val = tensor([true, false, false])]; tensor var_162 = slice_by_index(begin = var_162_begin_0, end = var_162_end_0, end_mask = var_162_end_mask_0, squeeze_mask = var_162_squeeze_mask_0, x = tuples)[name = tensor("op_162")]; tensor var_163 = const()[name = tensor("op_163"), val = tensor(-0x1p-1)]; tensor var_164 = greater(x = var_162, y = var_163)[name = tensor("op_164")]; tensor var_164_promoted_dtype_0 = const()[name = tensor("op_164_promoted_dtype_0"), val = tensor("int32")]; tensor not_padded_mask_promoted_dtype_0 = const()[name = tensor("not_padded_mask_promoted_dtype_0"), val = tensor("fp32")]; tensor var_164_to_fp32 = cast(dtype = not_padded_mask_promoted_dtype_0, x = var_164)[name = tensor("cast_254")]; tensor x_3 = mul(x = var_162, y = var_164_to_fp32)[name = tensor("x_3")]; tensor var_172_axes_0 = const()[name = tensor("op_172_axes_0"), val = tensor([1])]; tensor var_172_keep_dims_0 = const()[name = tensor("op_172_keep_dims_0"), val = tensor(false)]; tensor var_164_promoted = cast(dtype = var_164_promoted_dtype_0, x = var_164)[name = tensor("cast_255")]; tensor var_172 = reduce_sum(axes = var_172_axes_0, keep_dims = var_172_keep_dims_0, x = var_164_promoted)[name = tensor("op_172")]; tensor var_173 = const()[name = tensor("op_173"), val = tensor(0x0p+0)]; tensor var_172_promoted_dtype_0 = const()[name = tensor("op_172_promoted_dtype_0"), val = tensor("fp32")]; tensor var_172_promoted = cast(dtype = var_172_promoted_dtype_0, x = var_172)[name = tensor("cast_253")]; tensor var_174 = greater(x = var_172_promoted, y = var_173)[name = tensor("op_174")]; tensor var_174_promoted_dtype_0 = const()[name = tensor("op_174_promoted_dtype_0"), val = tensor("fp32")]; tensor var_182_axes_0 = const()[name = tensor("op_182_axes_0"), val = tensor([0])]; tensor var_182_keep_dims_0 = const()[name = tensor("op_182_keep_dims_0"), val = tensor(false)]; tensor var_182 = reduce_sum(axes = var_182_axes_0, keep_dims = var_182_keep_dims_0, x = var_164_promoted)[name = tensor("op_182")]; tensor var_183 = const()[name = tensor("op_183"), val = tensor(0x0p+0)]; tensor var_182_promoted_dtype_0 = const()[name = tensor("op_182_promoted_dtype_0"), val = tensor("fp32")]; tensor var_182_promoted = cast(dtype = var_182_promoted_dtype_0, x = var_182)[name = tensor("cast_251")]; tensor var_184 = greater(x = var_182_promoted, y = var_183)[name = tensor("op_184")]; tensor var_184_promoted_dtype_0 = const()[name = tensor("op_184_promoted_dtype_0"), val = tensor("fp32")]; tensor var_193_begin_0 = const()[name = tensor("op_193_begin_0"), val = tensor([4, 0])]; tensor var_193_end_0 = const()[name = tensor("op_193_end_0"), val = tensor([5, 1000])]; tensor var_193_end_mask_0 = const()[name = tensor("op_193_end_mask_0"), val = tensor([false, true])]; tensor var_193_squeeze_mask_0 = const()[name = tensor("op_193_squeeze_mask_0"), val = tensor([true, false])]; tensor var_193 = slice_by_index(begin = var_193_begin_0, end = var_193_end_0, end_mask = var_193_end_mask_0, squeeze_mask = var_193_squeeze_mask_0, x = reshape_4)[name = tensor("op_193")]; tensor var_194 = const()[name = tensor("op_194"), val = tensor(-0x1p-1)]; tensor var_195 = less(x = var_193, y = var_194)[name = tensor("op_195")]; tensor var_195_promoted_dtype_0 = const()[name = tensor("op_195_promoted_dtype_0"), val = tensor("fp32")]; tensor var_203 = const()[name = tensor("op_203"), val = tensor(0x1.0609bep+3)]; tensor var_204 = sub(x = var_193, y = var_203)[name = tensor("op_204")]; tensor cand_time_1 = exp(x = var_204)[name = tensor("cand_time_1")]; tensor var_206_promoted = const()[name = tensor("op_206_promoted"), val = tensor(0x1p+0)]; tensor var_195_promoted = cast(dtype = var_195_promoted_dtype_0, x = var_195)[name = tensor("cast_249")]; tensor var_208 = sub(x = var_206_promoted, y = var_195_promoted)[name = tensor("op_208")]; tensor var_209 = const()[name = tensor("op_209"), val = tensor(0x1.0c1524p-2)]; tensor var_210 = mul(x = cand_time_1, y = var_209)[name = tensor("op_210")]; tensor var_211 = cos(x = var_210)[name = tensor("op_211")]; tensor var_212_promoted = const()[name = tensor("op_212_promoted"), val = tensor(-0x1p+0)]; tensor var_213 = mul(x = var_211, y = var_212_promoted)[name = tensor("op_213")]; tensor var_214 = const()[name = tensor("op_214"), val = tensor(0x1p-1)]; tensor var_215 = mul(x = var_213, y = var_214)[name = tensor("op_215")]; tensor var_217 = const()[name = tensor("op_217"), val = tensor(0x1p-1)]; tensor var_218 = add(x = var_215, y = var_217)[name = tensor("op_218")]; tensor var_219 = mul(x = var_208, y = var_218)[name = tensor("op_219")]; tensor var_220_promoted = const()[name = tensor("op_220_promoted"), val = tensor(-0x1p+0)]; tensor var_221 = mul(x = var_195_promoted, y = var_220_promoted)[name = tensor("op_221")]; tensor cand_freq_1 = add(x = var_219, y = var_221)[name = tensor("cand_freq_1")]; tensor var_225_axes_0 = const()[name = tensor("op_225_axes_0"), val = tensor([0])]; tensor var_225 = expand_dims(axes = var_225_axes_0, x = cand_freq_1)[name = tensor("op_225")]; tensor var_227 = const()[name = tensor("op_227"), val = tensor(0)]; tensor var_228_interleave_0 = const()[name = tensor("op_228_interleave_0"), val = tensor(false)]; tensor var_228 = concat(axis = var_227, interleave = var_228_interleave_0, values = (reshape_4, var_225))[name = tensor("op_228")]; tensor var_237_begin_0 = const()[name = tensor("op_237_begin_0"), val = tensor([4, 0])]; tensor var_237_end_0 = const()[name = tensor("op_237_end_0"), val = tensor([5, 1000])]; tensor var_237_end_mask_0 = const()[name = tensor("op_237_end_mask_0"), val = tensor([false, true])]; tensor var_237_squeeze_mask_0 = const()[name = tensor("op_237_squeeze_mask_0"), val = tensor([true, false])]; tensor var_237 = slice_by_index(begin = var_237_begin_0, end = var_237_end_0, end_mask = var_237_end_mask_0, squeeze_mask = var_237_squeeze_mask_0, x = var_228)[name = tensor("op_237")]; tensor var_238 = const()[name = tensor("op_238"), val = tensor(-0x1p-1)]; tensor var_239 = less(x = var_237, y = var_238)[name = tensor("op_239")]; tensor var_239_promoted_dtype_0 = const()[name = tensor("op_239_promoted_dtype_0"), val = tensor("fp32")]; tensor var_247 = const()[name = tensor("op_247"), val = tensor(0x1.0609bep+3)]; tensor var_248 = sub(x = var_237, y = var_247)[name = tensor("op_248")]; tensor cand_time_3 = exp(x = var_248)[name = tensor("cand_time_3")]; tensor var_250_promoted = const()[name = tensor("op_250_promoted"), val = tensor(0x1p+0)]; tensor var_239_promoted = cast(dtype = var_239_promoted_dtype_0, x = var_239)[name = tensor("cast_248")]; tensor var_252 = sub(x = var_250_promoted, y = var_239_promoted)[name = tensor("op_252")]; tensor var_253 = const()[name = tensor("op_253"), val = tensor(0x1.32614ep-5)]; tensor var_254 = mul(x = cand_time_3, y = var_253)[name = tensor("op_254")]; tensor var_255 = cos(x = var_254)[name = tensor("op_255")]; tensor var_256_promoted = const()[name = tensor("op_256_promoted"), val = tensor(-0x1p+0)]; tensor var_257 = mul(x = var_255, y = var_256_promoted)[name = tensor("op_257")]; tensor var_258 = const()[name = tensor("op_258"), val = tensor(0x1p-1)]; tensor var_259 = mul(x = var_257, y = var_258)[name = tensor("op_259")]; tensor var_261 = const()[name = tensor("op_261"), val = tensor(0x1p-1)]; tensor var_262 = add(x = var_259, y = var_261)[name = tensor("op_262")]; tensor var_263 = mul(x = var_252, y = var_262)[name = tensor("op_263")]; tensor var_264_promoted = const()[name = tensor("op_264_promoted"), val = tensor(-0x1p+0)]; tensor var_265 = mul(x = var_239_promoted, y = var_264_promoted)[name = tensor("op_265")]; tensor cand_freq_3 = add(x = var_263, y = var_265)[name = tensor("cand_freq_3")]; tensor var_269_axes_0 = const()[name = tensor("op_269_axes_0"), val = tensor([0])]; tensor var_269 = expand_dims(axes = var_269_axes_0, x = cand_freq_3)[name = tensor("op_269")]; tensor var_271 = const()[name = tensor("op_271"), val = tensor(0)]; tensor var_272_interleave_0 = const()[name = tensor("op_272_interleave_0"), val = tensor(false)]; tensor var_272 = concat(axis = var_271, interleave = var_272_interleave_0, values = (var_228, var_269))[name = tensor("op_272")]; tensor candidate_by_column_5_perm_0 = const()[name = tensor("candidate_by_column_5_perm_0"), val = tensor([1, 0])]; tensor var_281_begin_0 = const()[name = tensor("op_281_begin_0"), val = tensor([3, 0])]; tensor var_281_end_0 = const()[name = tensor("op_281_end_0"), val = tensor([4, 1000])]; tensor var_281_end_mask_0 = const()[name = tensor("op_281_end_mask_0"), val = tensor([false, true])]; tensor var_281_squeeze_mask_0 = const()[name = tensor("op_281_squeeze_mask_0"), val = tensor([true, false])]; tensor candidate_by_column_5 = transpose(perm = candidate_by_column_5_perm_0, x = tupleInteractions_alignment)[name = tensor("transpose_47")]; tensor var_281 = slice_by_index(begin = var_281_begin_0, end = var_281_end_0, end_mask = var_281_end_mask_0, squeeze_mask = var_281_squeeze_mask_0, x = candidate_by_column_5)[name = tensor("op_281")]; tensor var_282 = const()[name = tensor("op_282"), val = tensor(-0x1p-1)]; tensor var_283 = less(x = var_281, y = var_282)[name = tensor("op_283")]; tensor var_283_promoted_dtype_0 = const()[name = tensor("op_283_promoted_dtype_0"), val = tensor("fp32")]; tensor var_291 = const()[name = tensor("op_291"), val = tensor(0x1.0609bep+3)]; tensor var_292 = sub(x = var_281, y = var_291)[name = tensor("op_292")]; tensor cand_time_5 = exp(x = var_292)[name = tensor("cand_time_5")]; tensor var_294_promoted = const()[name = tensor("op_294_promoted"), val = tensor(0x1p+0)]; tensor var_283_promoted = cast(dtype = var_283_promoted_dtype_0, x = var_283)[name = tensor("cast_247")]; tensor var_296 = sub(x = var_294_promoted, y = var_283_promoted)[name = tensor("op_296")]; tensor var_297 = const()[name = tensor("op_297"), val = tensor(0x1.0c1524p-2)]; tensor var_298 = mul(x = cand_time_5, y = var_297)[name = tensor("op_298")]; tensor var_299 = cos(x = var_298)[name = tensor("op_299")]; tensor var_300_promoted = const()[name = tensor("op_300_promoted"), val = tensor(-0x1p+0)]; tensor var_301 = mul(x = var_299, y = var_300_promoted)[name = tensor("op_301")]; tensor var_302 = const()[name = tensor("op_302"), val = tensor(0x1p-1)]; tensor var_303 = mul(x = var_301, y = var_302)[name = tensor("op_303")]; tensor var_305 = const()[name = tensor("op_305"), val = tensor(0x1p-1)]; tensor var_306 = add(x = var_303, y = var_305)[name = tensor("op_306")]; tensor var_307 = mul(x = var_296, y = var_306)[name = tensor("op_307")]; tensor var_308_promoted = const()[name = tensor("op_308_promoted"), val = tensor(-0x1p+0)]; tensor var_309 = mul(x = var_283_promoted, y = var_308_promoted)[name = tensor("op_309")]; tensor cand_freq_5 = add(x = var_307, y = var_309)[name = tensor("cand_freq_5")]; tensor var_313_axes_0 = const()[name = tensor("op_313_axes_0"), val = tensor([0])]; tensor var_313 = expand_dims(axes = var_313_axes_0, x = cand_freq_5)[name = tensor("op_313")]; tensor var_315 = const()[name = tensor("op_315"), val = tensor(0)]; tensor var_316_interleave_0 = const()[name = tensor("op_316_interleave_0"), val = tensor(false)]; tensor var_316 = concat(axis = var_315, interleave = var_316_interleave_0, values = (candidate_by_column_5, var_313))[name = tensor("op_316")]; tensor var_325_begin_0 = const()[name = tensor("op_325_begin_0"), val = tensor([3, 0])]; tensor var_325_end_0 = const()[name = tensor("op_325_end_0"), val = tensor([4, 1000])]; tensor var_325_end_mask_0 = const()[name = tensor("op_325_end_mask_0"), val = tensor([false, true])]; tensor var_325_squeeze_mask_0 = const()[name = tensor("op_325_squeeze_mask_0"), val = tensor([true, false])]; tensor var_325 = slice_by_index(begin = var_325_begin_0, end = var_325_end_0, end_mask = var_325_end_mask_0, squeeze_mask = var_325_squeeze_mask_0, x = var_316)[name = tensor("op_325")]; tensor var_326 = const()[name = tensor("op_326"), val = tensor(-0x1p-1)]; tensor var_327 = less(x = var_325, y = var_326)[name = tensor("op_327")]; tensor var_327_promoted_dtype_0 = const()[name = tensor("op_327_promoted_dtype_0"), val = tensor("fp32")]; tensor var_335 = const()[name = tensor("op_335"), val = tensor(0x1.0609bep+3)]; tensor var_336 = sub(x = var_325, y = var_335)[name = tensor("op_336")]; tensor cand_time = exp(x = var_336)[name = tensor("cand_time")]; tensor var_338_promoted = const()[name = tensor("op_338_promoted"), val = tensor(0x1p+0)]; tensor var_327_promoted = cast(dtype = var_327_promoted_dtype_0, x = var_327)[name = tensor("cast_246")]; tensor var_340 = sub(x = var_338_promoted, y = var_327_promoted)[name = tensor("op_340")]; tensor var_341 = const()[name = tensor("op_341"), val = tensor(0x1.32614ep-5)]; tensor var_342 = mul(x = cand_time, y = var_341)[name = tensor("op_342")]; tensor var_343 = cos(x = var_342)[name = tensor("op_343")]; tensor var_344_promoted = const()[name = tensor("op_344_promoted"), val = tensor(-0x1p+0)]; tensor var_345 = mul(x = var_343, y = var_344_promoted)[name = tensor("op_345")]; tensor var_346 = const()[name = tensor("op_346"), val = tensor(0x1p-1)]; tensor var_347 = mul(x = var_345, y = var_346)[name = tensor("op_347")]; tensor var_349 = const()[name = tensor("op_349"), val = tensor(0x1p-1)]; tensor var_350 = add(x = var_347, y = var_349)[name = tensor("op_350")]; tensor var_351 = mul(x = var_340, y = var_350)[name = tensor("op_351")]; tensor var_352_promoted = const()[name = tensor("op_352_promoted"), val = tensor(-0x1p+0)]; tensor var_353 = mul(x = var_327_promoted, y = var_352_promoted)[name = tensor("op_353")]; tensor cand_freq = add(x = var_351, y = var_353)[name = tensor("cand_freq")]; tensor var_357_axes_0 = const()[name = tensor("op_357_axes_0"), val = tensor([0])]; tensor var_357 = expand_dims(axes = var_357_axes_0, x = cand_freq)[name = tensor("op_357")]; tensor var_359 = const()[name = tensor("op_359"), val = tensor(0)]; tensor var_360_interleave_0 = const()[name = tensor("op_360_interleave_0"), val = tensor(false)]; tensor var_360 = concat(axis = var_359, interleave = var_360_interleave_0, values = (var_316, var_357))[name = tensor("op_360")]; tensor var_364 = const()[name = tensor("op_364"), val = tensor([0x0p+0, 0x0p+0])]; tensor var_366 = const()[name = tensor("op_366"), val = tensor(0)]; tensor device_context_interleave_0 = const()[name = tensor("device_context_interleave_0"), val = tensor(false)]; tensor device_context = concat(axis = var_366, interleave = device_context_interleave_0, values = (deviceContext, var_364))[name = tensor("device_context")]; tensor var_380 = const()[name = tensor("op_380"), val = tensor(-0x1p-1)]; tensor transpose_12_perm_0 = const()[name = tensor("transpose_12_perm_0"), val = tensor([1, 0])]; tensor transpose_12 = transpose(perm = transpose_12_perm_0, x = var_272)[name = tensor("transpose_46")]; tensor var_381 = greater(x = transpose_12, y = var_380)[name = tensor("op_381")]; tensor var_381_promoted_dtype_0 = const()[name = tensor("op_381_promoted_dtype_0"), val = tensor("fp32")]; tensor var_385 = const()[name = tensor("op_385"), val = tensor(-0x1p-1)]; tensor transpose_13_perm_0 = const()[name = tensor("transpose_13_perm_0"), val = tensor([1, 0])]; tensor transpose_13 = transpose(perm = transpose_13_perm_0, x = var_360)[name = tensor("transpose_45")]; tensor var_386 = greater(x = transpose_13, y = var_385)[name = tensor("op_386")]; tensor var_386_promoted_dtype_0 = const()[name = tensor("op_386_promoted_dtype_0"), val = tensor("fp32")]; tensor var_392_begin_0 = const()[name = tensor("op_392_begin_0"), val = tensor([4, 0])]; tensor var_392_end_0 = const()[name = tensor("op_392_end_0"), val = tensor([5, 1000])]; tensor var_392_end_mask_0 = const()[name = tensor("op_392_end_mask_0"), val = tensor([false, true])]; tensor var_392_squeeze_mask_0 = const()[name = tensor("op_392_squeeze_mask_0"), val = tensor([true, false])]; tensor var_392 = slice_by_index(begin = var_392_begin_0, end = var_392_end_0, end_mask = var_392_end_mask_0, squeeze_mask = var_392_squeeze_mask_0, x = var_272)[name = tensor("op_392")]; tensor var_393_keep_dims_0 = const()[name = tensor("op_393_keep_dims_0"), val = tensor(false)]; tensor var_393 = reduce_max(keep_dims = var_393_keep_dims_0, x = var_392)[name = tensor("op_393")]; tensor var_395_promoted = const()[name = tensor("op_395_promoted"), val = tensor(0x1.cp+2)]; tensor time_correction = sub(x = var_393, y = var_395_promoted)[name = tensor("time_correction")]; tensor var_401 = sub(x = var_392, y = time_correction)[name = tensor("op_401")]; tensor var_402_promoted = const()[name = tensor("op_402_promoted"), val = tensor(-0x1.cp+2)]; tensor var_403_promoted = const()[name = tensor("op_403_promoted"), val = tensor(0x1.cp+2)]; tensor clip_0 = clip(alpha = var_402_promoted, beta = var_403_promoted, x = var_401)[name = tensor("clip_0")]; tensor var_405 = exp(x = clip_0)[name = tensor("op_405")]; tensor shape_16 = const()[name = tensor("shape_16"), val = tensor([10, 1000])]; tensor slice_by_index_25 = const()[name = tensor("slice_by_index_25"), val = tensor([4000, 4001, 4002, 4003, 4004, 4005, 4006, 4007, 4008, 4009, 4010, 4011, 4012, 4013, 4014, 4015, 4016, 4017, 4018, 4019, 4020, 4021, 4022, 4023, 4024, 4025, 4026, 4027, 4028, 4029, 4030, 4031, 4032, 4033, 4034, 4035, 4036, 4037, 4038, 4039, 4040, 4041, 4042, 4043, 4044, 4045, 4046, 4047, 4048, 4049, 4050, 4051, 4052, 4053, 4054, 4055, 4056, 4057, 4058, 4059, 4060, 4061, 4062, 4063, 4064, 4065, 4066, 4067, 4068, 4069, 4070, 4071, 4072, 4073, 4074, 4075, 4076, 4077, 4078, 4079, 4080, 4081, 4082, 4083, 4084, 4085, 4086, 4087, 4088, 4089, 4090, 4091, 4092, 4093, 4094, 4095, 4096, 4097, 4098, 4099, 4100, 4101, 4102, 4103, 4104, 4105, 4106, 4107, 4108, 4109, 4110, 4111, 4112, 4113, 4114, 4115, 4116, 4117, 4118, 4119, 4120, 4121, 4122, 4123, 4124, 4125, 4126, 4127, 4128, 4129, 4130, 4131, 4132, 4133, 4134, 4135, 4136, 4137, 4138, 4139, 4140, 4141, 4142, 4143, 4144, 4145, 4146, 4147, 4148, 4149, 4150, 4151, 4152, 4153, 4154, 4155, 4156, 4157, 4158, 4159, 4160, 4161, 4162, 4163, 4164, 4165, 4166, 4167, 4168, 4169, 4170, 4171, 4172, 4173, 4174, 4175, 4176, 4177, 4178, 4179, 4180, 4181, 4182, 4183, 4184, 4185, 4186, 4187, 4188, 4189, 4190, 4191, 4192, 4193, 4194, 4195, 4196, 4197, 4198, 4199, 4200, 4201, 4202, 4203, 4204, 4205, 4206, 4207, 4208, 4209, 4210, 4211, 4212, 4213, 4214, 4215, 4216, 4217, 4218, 4219, 4220, 4221, 4222, 4223, 4224, 4225, 4226, 4227, 4228, 4229, 4230, 4231, 4232, 4233, 4234, 4235, 4236, 4237, 4238, 4239, 4240, 4241, 4242, 4243, 4244, 4245, 4246, 4247, 4248, 4249, 4250, 4251, 4252, 4253, 4254, 4255, 4256, 4257, 4258, 4259, 4260, 4261, 4262, 4263, 4264, 4265, 4266, 4267, 4268, 4269, 4270, 4271, 4272, 4273, 4274, 4275, 4276, 4277, 4278, 4279, 4280, 4281, 4282, 4283, 4284, 4285, 4286, 4287, 4288, 4289, 4290, 4291, 4292, 4293, 4294, 4295, 4296, 4297, 4298, 4299, 4300, 4301, 4302, 4303, 4304, 4305, 4306, 4307, 4308, 4309, 4310, 4311, 4312, 4313, 4314, 4315, 4316, 4317, 4318, 4319, 4320, 4321, 4322, 4323, 4324, 4325, 4326, 4327, 4328, 4329, 4330, 4331, 4332, 4333, 4334, 4335, 4336, 4337, 4338, 4339, 4340, 4341, 4342, 4343, 4344, 4345, 4346, 4347, 4348, 4349, 4350, 4351, 4352, 4353, 4354, 4355, 4356, 4357, 4358, 4359, 4360, 4361, 4362, 4363, 4364, 4365, 4366, 4367, 4368, 4369, 4370, 4371, 4372, 4373, 4374, 4375, 4376, 4377, 4378, 4379, 4380, 4381, 4382, 4383, 4384, 4385, 4386, 4387, 4388, 4389, 4390, 4391, 4392, 4393, 4394, 4395, 4396, 4397, 4398, 4399, 4400, 4401, 4402, 4403, 4404, 4405, 4406, 4407, 4408, 4409, 4410, 4411, 4412, 4413, 4414, 4415, 4416, 4417, 4418, 4419, 4420, 4421, 4422, 4423, 4424, 4425, 4426, 4427, 4428, 4429, 4430, 4431, 4432, 4433, 4434, 4435, 4436, 4437, 4438, 4439, 4440, 4441, 4442, 4443, 4444, 4445, 4446, 4447, 4448, 4449, 4450, 4451, 4452, 4453, 4454, 4455, 4456, 4457, 4458, 4459, 4460, 4461, 4462, 4463, 4464, 4465, 4466, 4467, 4468, 4469, 4470, 4471, 4472, 4473, 4474, 4475, 4476, 4477, 4478, 4479, 4480, 4481, 4482, 4483, 4484, 4485, 4486, 4487, 4488, 4489, 4490, 4491, 4492, 4493, 4494, 4495, 4496, 4497, 4498, 4499, 4500, 4501, 4502, 4503, 4504, 4505, 4506, 4507, 4508, 4509, 4510, 4511, 4512, 4513, 4514, 4515, 4516, 4517, 4518, 4519, 4520, 4521, 4522, 4523, 4524, 4525, 4526, 4527, 4528, 4529, 4530, 4531, 4532, 4533, 4534, 4535, 4536, 4537, 4538, 4539, 4540, 4541, 4542, 4543, 4544, 4545, 4546, 4547, 4548, 4549, 4550, 4551, 4552, 4553, 4554, 4555, 4556, 4557, 4558, 4559, 4560, 4561, 4562, 4563, 4564, 4565, 4566, 4567, 4568, 4569, 4570, 4571, 4572, 4573, 4574, 4575, 4576, 4577, 4578, 4579, 4580, 4581, 4582, 4583, 4584, 4585, 4586, 4587, 4588, 4589, 4590, 4591, 4592, 4593, 4594, 4595, 4596, 4597, 4598, 4599, 4600, 4601, 4602, 4603, 4604, 4605, 4606, 4607, 4608, 4609, 4610, 4611, 4612, 4613, 4614, 4615, 4616, 4617, 4618, 4619, 4620, 4621, 4622, 4623, 4624, 4625, 4626, 4627, 4628, 4629, 4630, 4631, 4632, 4633, 4634, 4635, 4636, 4637, 4638, 4639, 4640, 4641, 4642, 4643, 4644, 4645, 4646, 4647, 4648, 4649, 4650, 4651, 4652, 4653, 4654, 4655, 4656, 4657, 4658, 4659, 4660, 4661, 4662, 4663, 4664, 4665, 4666, 4667, 4668, 4669, 4670, 4671, 4672, 4673, 4674, 4675, 4676, 4677, 4678, 4679, 4680, 4681, 4682, 4683, 4684, 4685, 4686, 4687, 4688, 4689, 4690, 4691, 4692, 4693, 4694, 4695, 4696, 4697, 4698, 4699, 4700, 4701, 4702, 4703, 4704, 4705, 4706, 4707, 4708, 4709, 4710, 4711, 4712, 4713, 4714, 4715, 4716, 4717, 4718, 4719, 4720, 4721, 4722, 4723, 4724, 4725, 4726, 4727, 4728, 4729, 4730, 4731, 4732, 4733, 4734, 4735, 4736, 4737, 4738, 4739, 4740, 4741, 4742, 4743, 4744, 4745, 4746, 4747, 4748, 4749, 4750, 4751, 4752, 4753, 4754, 4755, 4756, 4757, 4758, 4759, 4760, 4761, 4762, 4763, 4764, 4765, 4766, 4767, 4768, 4769, 4770, 4771, 4772, 4773, 4774, 4775, 4776, 4777, 4778, 4779, 4780, 4781, 4782, 4783, 4784, 4785, 4786, 4787, 4788, 4789, 4790, 4791, 4792, 4793, 4794, 4795, 4796, 4797, 4798, 4799, 4800, 4801, 4802, 4803, 4804, 4805, 4806, 4807, 4808, 4809, 4810, 4811, 4812, 4813, 4814, 4815, 4816, 4817, 4818, 4819, 4820, 4821, 4822, 4823, 4824, 4825, 4826, 4827, 4828, 4829, 4830, 4831, 4832, 4833, 4834, 4835, 4836, 4837, 4838, 4839, 4840, 4841, 4842, 4843, 4844, 4845, 4846, 4847, 4848, 4849, 4850, 4851, 4852, 4853, 4854, 4855, 4856, 4857, 4858, 4859, 4860, 4861, 4862, 4863, 4864, 4865, 4866, 4867, 4868, 4869, 4870, 4871, 4872, 4873, 4874, 4875, 4876, 4877, 4878, 4879, 4880, 4881, 4882, 4883, 4884, 4885, 4886, 4887, 4888, 4889, 4890, 4891, 4892, 4893, 4894, 4895, 4896, 4897, 4898, 4899, 4900, 4901, 4902, 4903, 4904, 4905, 4906, 4907, 4908, 4909, 4910, 4911, 4912, 4913, 4914, 4915, 4916, 4917, 4918, 4919, 4920, 4921, 4922, 4923, 4924, 4925, 4926, 4927, 4928, 4929, 4930, 4931, 4932, 4933, 4934, 4935, 4936, 4937, 4938, 4939, 4940, 4941, 4942, 4943, 4944, 4945, 4946, 4947, 4948, 4949, 4950, 4951, 4952, 4953, 4954, 4955, 4956, 4957, 4958, 4959, 4960, 4961, 4962, 4963, 4964, 4965, 4966, 4967, 4968, 4969, 4970, 4971, 4972, 4973, 4974, 4975, 4976, 4977, 4978, 4979, 4980, 4981, 4982, 4983, 4984, 4985, 4986, 4987, 4988, 4989, 4990, 4991, 4992, 4993, 4994, 4995, 4996, 4997, 4998, 4999])]; tensor reshape_8_shape_0 = const()[name = tensor("reshape_8_shape_0"), val = tensor([-1])]; tensor reshape_8 = reshape(shape = reshape_8_shape_0, x = var_272)[name = tensor("reshape_8")]; tensor scatter_1_mode_0 = const()[name = tensor("scatter_1_mode_0"), val = tensor("update")]; tensor scatter_1_axis_0 = const()[name = tensor("scatter_1_axis_0"), val = tensor(0)]; tensor scatter_1 = scatter(axis = scatter_1_axis_0, data = reshape_8, indices = slice_by_index_25, mode = scatter_1_mode_0, updates = var_405)[name = tensor("scatter_1")]; tensor reshape_9 = reshape(shape = shape_16, x = scatter_1)[name = tensor("reshape_9")]; tensor candidate_interactions_13_perm_0 = const()[name = tensor("candidate_interactions_13_perm_0"), val = tensor([1, 0])]; tensor var_416_begin_0 = const()[name = tensor("op_416_begin_0"), val = tensor([3, 0])]; tensor var_416_end_0 = const()[name = tensor("op_416_end_0"), val = tensor([4, 1000])]; tensor var_416_end_mask_0 = const()[name = tensor("op_416_end_mask_0"), val = tensor([false, true])]; tensor var_416_squeeze_mask_0 = const()[name = tensor("op_416_squeeze_mask_0"), val = tensor([true, false])]; tensor var_416 = slice_by_index(begin = var_416_begin_0, end = var_416_end_0, end_mask = var_416_end_mask_0, squeeze_mask = var_416_squeeze_mask_0, x = var_360)[name = tensor("op_416")]; tensor var_418 = sub(x = var_416, y = time_correction)[name = tensor("op_418")]; tensor var_419_promoted = const()[name = tensor("op_419_promoted"), val = tensor(-0x1.cp+2)]; tensor var_420_promoted = const()[name = tensor("op_420_promoted"), val = tensor(0x1.cp+2)]; tensor clip_1 = clip(alpha = var_419_promoted, beta = var_420_promoted, x = var_418)[name = tensor("clip_1")]; tensor var_422 = exp(x = clip_1)[name = tensor("op_422")]; tensor shape_17 = const()[name = tensor("shape_17"), val = tensor([9, 1000])]; tensor slice_by_index_26 = const()[name = tensor("slice_by_index_26"), val = tensor([3000, 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, 3009, 3010, 3011, 3012, 3013, 3014, 3015, 3016, 3017, 3018, 3019, 3020, 3021, 3022, 3023, 3024, 3025, 3026, 3027, 3028, 3029, 3030, 3031, 3032, 3033, 3034, 3035, 3036, 3037, 3038, 3039, 3040, 3041, 3042, 3043, 3044, 3045, 3046, 3047, 3048, 3049, 3050, 3051, 3052, 3053, 3054, 3055, 3056, 3057, 3058, 3059, 3060, 3061, 3062, 3063, 3064, 3065, 3066, 3067, 3068, 3069, 3070, 3071, 3072, 3073, 3074, 3075, 3076, 3077, 3078, 3079, 3080, 3081, 3082, 3083, 3084, 3085, 3086, 3087, 3088, 3089, 3090, 3091, 3092, 3093, 3094, 3095, 3096, 3097, 3098, 3099, 3100, 3101, 3102, 3103, 3104, 3105, 3106, 3107, 3108, 3109, 3110, 3111, 3112, 3113, 3114, 3115, 3116, 3117, 3118, 3119, 3120, 3121, 3122, 3123, 3124, 3125, 3126, 3127, 3128, 3129, 3130, 3131, 3132, 3133, 3134, 3135, 3136, 3137, 3138, 3139, 3140, 3141, 3142, 3143, 3144, 3145, 3146, 3147, 3148, 3149, 3150, 3151, 3152, 3153, 3154, 3155, 3156, 3157, 3158, 3159, 3160, 3161, 3162, 3163, 3164, 3165, 3166, 3167, 3168, 3169, 3170, 3171, 3172, 3173, 3174, 3175, 3176, 3177, 3178, 3179, 3180, 3181, 3182, 3183, 3184, 3185, 3186, 3187, 3188, 3189, 3190, 3191, 3192, 3193, 3194, 3195, 3196, 3197, 3198, 3199, 3200, 3201, 3202, 3203, 3204, 3205, 3206, 3207, 3208, 3209, 3210, 3211, 3212, 3213, 3214, 3215, 3216, 3217, 3218, 3219, 3220, 3221, 3222, 3223, 3224, 3225, 3226, 3227, 3228, 3229, 3230, 3231, 3232, 3233, 3234, 3235, 3236, 3237, 3238, 3239, 3240, 3241, 3242, 3243, 3244, 3245, 3246, 3247, 3248, 3249, 3250, 3251, 3252, 3253, 3254, 3255, 3256, 3257, 3258, 3259, 3260, 3261, 3262, 3263, 3264, 3265, 3266, 3267, 3268, 3269, 3270, 3271, 3272, 3273, 3274, 3275, 3276, 3277, 3278, 3279, 3280, 3281, 3282, 3283, 3284, 3285, 3286, 3287, 3288, 3289, 3290, 3291, 3292, 3293, 3294, 3295, 3296, 3297, 3298, 3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308, 3309, 3310, 3311, 3312, 3313, 3314, 3315, 3316, 3317, 3318, 3319, 3320, 3321, 3322, 3323, 3324, 3325, 3326, 3327, 3328, 3329, 3330, 3331, 3332, 3333, 3334, 3335, 3336, 3337, 3338, 3339, 3340, 3341, 3342, 3343, 3344, 3345, 3346, 3347, 3348, 3349, 3350, 3351, 3352, 3353, 3354, 3355, 3356, 3357, 3358, 3359, 3360, 3361, 3362, 3363, 3364, 3365, 3366, 3367, 3368, 3369, 3370, 3371, 3372, 3373, 3374, 3375, 3376, 3377, 3378, 3379, 3380, 3381, 3382, 3383, 3384, 3385, 3386, 3387, 3388, 3389, 3390, 3391, 3392, 3393, 3394, 3395, 3396, 3397, 3398, 3399, 3400, 3401, 3402, 3403, 3404, 3405, 3406, 3407, 3408, 3409, 3410, 3411, 3412, 3413, 3414, 3415, 3416, 3417, 3418, 3419, 3420, 3421, 3422, 3423, 3424, 3425, 3426, 3427, 3428, 3429, 3430, 3431, 3432, 3433, 3434, 3435, 3436, 3437, 3438, 3439, 3440, 3441, 3442, 3443, 3444, 3445, 3446, 3447, 3448, 3449, 3450, 3451, 3452, 3453, 3454, 3455, 3456, 3457, 3458, 3459, 3460, 3461, 3462, 3463, 3464, 3465, 3466, 3467, 3468, 3469, 3470, 3471, 3472, 3473, 3474, 3475, 3476, 3477, 3478, 3479, 3480, 3481, 3482, 3483, 3484, 3485, 3486, 3487, 3488, 3489, 3490, 3491, 3492, 3493, 3494, 3495, 3496, 3497, 3498, 3499, 3500, 3501, 3502, 3503, 3504, 3505, 3506, 3507, 3508, 3509, 3510, 3511, 3512, 3513, 3514, 3515, 3516, 3517, 3518, 3519, 3520, 3521, 3522, 3523, 3524, 3525, 3526, 3527, 3528, 3529, 3530, 3531, 3532, 3533, 3534, 3535, 3536, 3537, 3538, 3539, 3540, 3541, 3542, 3543, 3544, 3545, 3546, 3547, 3548, 3549, 3550, 3551, 3552, 3553, 3554, 3555, 3556, 3557, 3558, 3559, 3560, 3561, 3562, 3563, 3564, 3565, 3566, 3567, 3568, 3569, 3570, 3571, 3572, 3573, 3574, 3575, 3576, 3577, 3578, 3579, 3580, 3581, 3582, 3583, 3584, 3585, 3586, 3587, 3588, 3589, 3590, 3591, 3592, 3593, 3594, 3595, 3596, 3597, 3598, 3599, 3600, 3601, 3602, 3603, 3604, 3605, 3606, 3607, 3608, 3609, 3610, 3611, 3612, 3613, 3614, 3615, 3616, 3617, 3618, 3619, 3620, 3621, 3622, 3623, 3624, 3625, 3626, 3627, 3628, 3629, 3630, 3631, 3632, 3633, 3634, 3635, 3636, 3637, 3638, 3639, 3640, 3641, 3642, 3643, 3644, 3645, 3646, 3647, 3648, 3649, 3650, 3651, 3652, 3653, 3654, 3655, 3656, 3657, 3658, 3659, 3660, 3661, 3662, 3663, 3664, 3665, 3666, 3667, 3668, 3669, 3670, 3671, 3672, 3673, 3674, 3675, 3676, 3677, 3678, 3679, 3680, 3681, 3682, 3683, 3684, 3685, 3686, 3687, 3688, 3689, 3690, 3691, 3692, 3693, 3694, 3695, 3696, 3697, 3698, 3699, 3700, 3701, 3702, 3703, 3704, 3705, 3706, 3707, 3708, 3709, 3710, 3711, 3712, 3713, 3714, 3715, 3716, 3717, 3718, 3719, 3720, 3721, 3722, 3723, 3724, 3725, 3726, 3727, 3728, 3729, 3730, 3731, 3732, 3733, 3734, 3735, 3736, 3737, 3738, 3739, 3740, 3741, 3742, 3743, 3744, 3745, 3746, 3747, 3748, 3749, 3750, 3751, 3752, 3753, 3754, 3755, 3756, 3757, 3758, 3759, 3760, 3761, 3762, 3763, 3764, 3765, 3766, 3767, 3768, 3769, 3770, 3771, 3772, 3773, 3774, 3775, 3776, 3777, 3778, 3779, 3780, 3781, 3782, 3783, 3784, 3785, 3786, 3787, 3788, 3789, 3790, 3791, 3792, 3793, 3794, 3795, 3796, 3797, 3798, 3799, 3800, 3801, 3802, 3803, 3804, 3805, 3806, 3807, 3808, 3809, 3810, 3811, 3812, 3813, 3814, 3815, 3816, 3817, 3818, 3819, 3820, 3821, 3822, 3823, 3824, 3825, 3826, 3827, 3828, 3829, 3830, 3831, 3832, 3833, 3834, 3835, 3836, 3837, 3838, 3839, 3840, 3841, 3842, 3843, 3844, 3845, 3846, 3847, 3848, 3849, 3850, 3851, 3852, 3853, 3854, 3855, 3856, 3857, 3858, 3859, 3860, 3861, 3862, 3863, 3864, 3865, 3866, 3867, 3868, 3869, 3870, 3871, 3872, 3873, 3874, 3875, 3876, 3877, 3878, 3879, 3880, 3881, 3882, 3883, 3884, 3885, 3886, 3887, 3888, 3889, 3890, 3891, 3892, 3893, 3894, 3895, 3896, 3897, 3898, 3899, 3900, 3901, 3902, 3903, 3904, 3905, 3906, 3907, 3908, 3909, 3910, 3911, 3912, 3913, 3914, 3915, 3916, 3917, 3918, 3919, 3920, 3921, 3922, 3923, 3924, 3925, 3926, 3927, 3928, 3929, 3930, 3931, 3932, 3933, 3934, 3935, 3936, 3937, 3938, 3939, 3940, 3941, 3942, 3943, 3944, 3945, 3946, 3947, 3948, 3949, 3950, 3951, 3952, 3953, 3954, 3955, 3956, 3957, 3958, 3959, 3960, 3961, 3962, 3963, 3964, 3965, 3966, 3967, 3968, 3969, 3970, 3971, 3972, 3973, 3974, 3975, 3976, 3977, 3978, 3979, 3980, 3981, 3982, 3983, 3984, 3985, 3986, 3987, 3988, 3989, 3990, 3991, 3992, 3993, 3994, 3995, 3996, 3997, 3998, 3999])]; tensor reshape_13_shape_0 = const()[name = tensor("reshape_13_shape_0"), val = tensor([-1])]; tensor reshape_13 = reshape(shape = reshape_13_shape_0, x = var_360)[name = tensor("reshape_13")]; tensor scatter_2_mode_0 = const()[name = tensor("scatter_2_mode_0"), val = tensor("update")]; tensor scatter_2_axis_0 = const()[name = tensor("scatter_2_axis_0"), val = tensor(0)]; tensor scatter_2 = scatter(axis = scatter_2_axis_0, data = reshape_13, indices = slice_by_index_26, mode = scatter_2_mode_0, updates = var_422)[name = tensor("scatter_2")]; tensor reshape_14 = reshape(shape = shape_17, x = scatter_2)[name = tensor("reshape_14")]; tensor tuple_alignments_perm_0 = const()[name = tensor("tuple_alignments_perm_0"), val = tensor([1, 0])]; tensor candidate_interactions_13 = transpose(perm = candidate_interactions_13_perm_0, x = reshape_9)[name = tensor("transpose_44")]; tensor var_381_promoted = cast(dtype = var_381_promoted_dtype_0, x = var_381)[name = tensor("cast_245")]; tensor var_431 = mul(x = var_381_promoted, y = candidate_interactions_13)[name = tensor("op_431")]; tensor var_432 = const()[name = tensor("op_432"), val = tensor(0x1p+0)]; tensor var_434 = sub(x = var_432, y = var_381_promoted)[name = tensor("op_434")]; tensor var_435 = const()[name = tensor("op_435"), val = tensor(-0x1p+0)]; tensor var_436 = mul(x = var_434, y = var_435)[name = tensor("op_436")]; tensor candidate_interactions = add(x = var_431, y = var_436)[name = tensor("candidate_interactions")]; tensor tuple_alignments = transpose(perm = tuple_alignments_perm_0, x = reshape_14)[name = tensor("transpose_43")]; tensor var_386_promoted = cast(dtype = var_386_promoted_dtype_0, x = var_386)[name = tensor("cast_244")]; tensor var_439 = mul(x = var_386_promoted, y = tuple_alignments)[name = tensor("op_439")]; tensor var_440 = const()[name = tensor("op_440"), val = tensor(0x1p+0)]; tensor var_442 = sub(x = var_440, y = var_386_promoted)[name = tensor("op_442")]; tensor var_443 = const()[name = tensor("op_443"), val = tensor(-0x1p+0)]; tensor var_444 = mul(x = var_442, y = var_443)[name = tensor("op_444")]; tensor tuple_alignment = add(x = var_439, y = var_444)[name = tensor("tuple_alignment")]; tensor candidate_interactions_transpose_perm_0 = const()[name = tensor("candidate_interactions_transpose_perm_0"), val = tensor([1, 0])]; tensor var_452 = const()[name = tensor("op_452"), val = tensor(0x1p-1)]; tensor candidate_interactions_transpose = transpose(perm = candidate_interactions_transpose_perm_0, x = candidate_interactions)[name = tensor("transpose_42")]; tensor var_453 = equal(x = candidate_interactions_transpose, y = var_452)[name = tensor("op_453")]; tensor cast_6_dtype_0 = const()[name = tensor("cast_6_dtype_0"), val = tensor("fp32")]; tensor mask_1_axes_0 = const()[name = tensor("mask_1_axes_0"), val = tensor([0])]; tensor mask_1_keep_dims_0 = const()[name = tensor("mask_1_keep_dims_0"), val = tensor(false)]; tensor cast_6 = cast(dtype = cast_6_dtype_0, x = var_453)[name = tensor("cast_243")]; tensor mask_1 = reduce_sum(axes = mask_1_axes_0, keep_dims = mask_1_keep_dims_0, x = cast_6)[name = tensor("mask_1")]; tensor var_464_begin_0 = const()[name = tensor("op_464_begin_0"), val = tensor([0, 0])]; tensor var_464_end_0 = const()[name = tensor("op_464_end_0"), val = tensor([1, 1000])]; tensor var_464_end_mask_0 = const()[name = tensor("op_464_end_mask_0"), val = tensor([false, true])]; tensor var_464_squeeze_mask_0 = const()[name = tensor("op_464_squeeze_mask_0"), val = tensor([true, false])]; tensor var_464 = slice_by_index(begin = var_464_begin_0, end = var_464_end_0, end_mask = var_464_end_mask_0, squeeze_mask = var_464_squeeze_mask_0, x = candidate_interactions_transpose)[name = tensor("op_464")]; tensor var_465_promoted = const()[name = tensor("op_465_promoted"), val = tensor(0x0p+0)]; tensor var_466 = mul(x = var_464, y = var_465_promoted)[name = tensor("op_466")]; tensor zero_slice_1_axes_0 = const()[name = tensor("zero_slice_1_axes_0"), val = tensor([0])]; tensor zero_slice_1 = expand_dims(axes = zero_slice_1_axes_0, x = var_466)[name = tensor("zero_slice_1")]; tensor var_470 = const()[name = tensor("op_470"), val = tensor(0)]; tensor context_feedback_3_interleave_0 = const()[name = tensor("context_feedback_3_interleave_0"), val = tensor(false)]; tensor context_feedback_3 = concat(axis = var_470, interleave = context_feedback_3_interleave_0, values = (candidate_interactions_transpose, zero_slice_1))[name = tensor("context_feedback_3")]; tensor var_476_begin_0 = const()[name = tensor("op_476_begin_0"), val = tensor([4, 0])]; tensor var_476_end_0 = const()[name = tensor("op_476_end_0"), val = tensor([10, 1000])]; tensor var_476_end_mask_0 = const()[name = tensor("op_476_end_mask_0"), val = tensor([false, true])]; tensor var_476 = slice_by_index(begin = var_476_begin_0, end = var_476_end_0, end_mask = var_476_end_mask_0, x = context_feedback_3)[name = tensor("op_476")]; tensor context_feedback_5_perm_0 = const()[name = tensor("context_feedback_5_perm_0"), val = tensor([1, 0])]; tensor fill_0 = const()[name = tensor("fill_0"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(448)))]; tensor var_495_axes_0 = const()[name = tensor("op_495_axes_0"), val = tensor([-1])]; tensor var_495 = expand_dims(axes = var_495_axes_0, x = mask_1)[name = tensor("op_495")]; tensor var_496 = mul(x = fill_0, y = var_495)[name = tensor("op_496")]; tensor var_497_promoted = const()[name = tensor("op_497_promoted"), val = tensor(-0x1p+0)]; tensor var_498 = mul(x = var_496, y = var_497_promoted)[name = tensor("op_498")]; tensor var_499_promoted = const()[name = tensor("op_499_promoted"), val = tensor(0x1p+0)]; tensor var_501 = sub(x = var_499_promoted, y = mask_1)[name = tensor("op_501")]; tensor var_503_axes_0 = const()[name = tensor("op_503_axes_0"), val = tensor([-1])]; tensor var_503 = expand_dims(axes = var_503_axes_0, x = var_501)[name = tensor("op_503")]; tensor context_feedback_5 = transpose(perm = context_feedback_5_perm_0, x = var_476)[name = tensor("transpose_41")]; tensor var_504 = mul(x = context_feedback_5, y = var_503)[name = tensor("op_504")]; tensor context_feedback_7 = add(x = var_498, y = var_504)[name = tensor("context_feedback_7")]; tensor var_508_axes_0 = const()[name = tensor("op_508_axes_0"), val = tensor([0])]; tensor var_508 = expand_dims(axes = var_508_axes_0, x = device_context)[name = tensor("op_508")]; tensor var_510 = const()[name = tensor("op_510"), val = tensor(0)]; tensor context_feedback_9_interleave_0 = const()[name = tensor("context_feedback_9_interleave_0"), val = tensor(false)]; tensor context_feedback_9 = concat(axis = var_510, interleave = context_feedback_9_interleave_0, values = (context_feedback_7, var_508))[name = tensor("context_feedback_9")]; tensor var_512 = const()[name = tensor("op_512"), val = tensor(-0x1.8p+0)]; tensor var_513 = greater(x = context_feedback_9, y = var_512)[name = tensor("op_513")]; tensor var_513_promoted_dtype_0 = const()[name = tensor("op_513_promoted_dtype_0"), val = tensor("fp32")]; tensor var_517 = const()[name = tensor("op_517"), val = tensor(-0x1p-1)]; tensor var_518 = less(x = context_feedback_9, y = var_517)[name = tensor("op_518")]; tensor var_518_promoted_dtype_0 = const()[name = tensor("op_518_promoted_dtype_0"), val = tensor("fp32")]; tensor var_518_promoted = cast(dtype = var_518_promoted_dtype_0, x = var_518)[name = tensor("cast_241")]; tensor var_513_promoted = cast(dtype = var_513_promoted_dtype_0, x = var_513)[name = tensor("cast_242")]; tensor var_522 = mul(x = var_513_promoted, y = var_518_promoted)[name = tensor("op_522")]; tensor var_523_promoted = const()[name = tensor("op_523_promoted"), val = tensor(0x1p+0)]; tensor padded_context_mask = sub(x = var_523_promoted, y = var_522)[name = tensor("padded_context_mask")]; tensor masked_context_1 = mul(x = padded_context_mask, y = context_feedback_9)[name = tensor("masked_context_1")]; tensor var_527 = abs(x = masked_context_1)[name = tensor("op_527")]; tensor scaled_keep_dims_0 = const()[name = tensor("scaled_keep_dims_0"), val = tensor(false)]; tensor scaled = reduce_max(keep_dims = scaled_keep_dims_0, x = var_527)[name = tensor("scaled")]; tensor var_530 = const()[name = tensor("op_530"), val = tensor(0x1.0624dep-10)]; tensor var_531 = add(x = scaled, y = var_530)[name = tensor("op_531")]; tensor var_532 = real_div(x = masked_context_1, y = var_531)[name = tensor("op_532")]; tensor var_533_promoted = const()[name = tensor("op_533_promoted"), val = tensor(0x1.f4p+9)]; tensor masked_context = mul(x = var_532, y = var_533_promoted)[name = tensor("masked_context")]; tensor modified_actual_sum_axes_0 = const()[name = tensor("modified_actual_sum_axes_0"), val = tensor([0])]; tensor modified_actual_sum_keep_dims_0 = const()[name = tensor("modified_actual_sum_keep_dims_0"), val = tensor(false)]; tensor modified_actual_sum = reduce_sum(axes = modified_actual_sum_axes_0, keep_dims = modified_actual_sum_keep_dims_0, x = masked_context)[name = tensor("modified_actual_sum")]; tensor modified_actual_count_axes_0 = const()[name = tensor("modified_actual_count_axes_0"), val = tensor([0])]; tensor modified_actual_count_keep_dims_0 = const()[name = tensor("modified_actual_count_keep_dims_0"), val = tensor(false)]; tensor modified_actual_count = reduce_sum(axes = modified_actual_count_axes_0, keep_dims = modified_actual_count_keep_dims_0, x = padded_context_mask)[name = tensor("modified_actual_count")]; tensor var_548 = const()[name = tensor("op_548"), val = tensor(0x1.0624dep-10)]; tensor var_549 = add(x = modified_actual_count, y = var_548)[name = tensor("op_549")]; tensor adjusted_mean_1 = real_div(x = modified_actual_sum, y = var_549)[name = tensor("adjusted_mean_1")]; tensor _inversed_552_y_0 = const()[name = tensor("_inversed_552_y_0"), val = tensor(0x1.0624dep-10)]; tensor _inversed_552 = mul(x = adjusted_mean_1, y = _inversed_552_y_0)[name = tensor("_inversed_552")]; tensor var_553 = mul(x = _inversed_552, y = scaled)[name = tensor("op_553")]; tensor not_equal_0 = not_equal(x = var_553, y = var_553)[name = tensor("not_equal_0")]; tensor cast_13_dtype_0 = const()[name = tensor("cast_13_dtype_0"), val = tensor("int32")]; tensor cast_13 = cast(dtype = cast_13_dtype_0, x = not_equal_0)[name = tensor("cast_240")]; tensor non_zero_0 = non_zero(x = cast_13)[name = tensor("non_zero_0")]; tensor expand_dims_0 = const()[name = tensor("expand_dims_0"), val = tensor([0x0p+0])]; tensor shape_2 = shape(x = non_zero_0)[name = tensor("shape_2")]; tensor slice_by_index_4_begin_0 = const()[name = tensor("slice_by_index_4_begin_0"), val = tensor([0])]; tensor slice_by_index_4_end_0 = const()[name = tensor("slice_by_index_4_end_0"), val = tensor([0])]; tensor slice_by_index_4_squeeze_mask_0 = const()[name = tensor("slice_by_index_4_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_4 = slice_by_index(begin = slice_by_index_4_begin_0, end = slice_by_index_4_end_0, squeeze_mask = slice_by_index_4_squeeze_mask_0, x = shape_2)[name = tensor("slice_by_index_4")]; tensor expand_dims_1_axes_0 = const()[name = tensor("expand_dims_1_axes_0"), val = tensor([0])]; tensor expand_dims_1 = expand_dims(axes = expand_dims_1_axes_0, x = slice_by_index_4)[name = tensor("expand_dims_1")]; tensor tile_0 = tile(reps = expand_dims_1, x = expand_dims_0)[name = tensor("tile_0")]; tensor scatter_nd_0_mode_0 = const()[name = tensor("scatter_nd_0_mode_0"), val = tensor("update")]; tensor scatter_nd_0 = scatter_nd(data = var_553, indices = non_zero_0, mode = scatter_nd_0_mode_0, updates = tile_0)[name = tensor("scatter_nd_0")]; tensor mul_0_y_0 = const()[name = tensor("mul_0_y_0"), val = tensor(0x0p+0)]; tensor mul_0 = mul(x = var_553, y = mul_0_y_0)[name = tensor("mul_0")]; tensor not_equal_1 = not_equal(x = mul_0, y = mul_0)[name = tensor("not_equal_1")]; tensor greater_0_y_0 = const()[name = tensor("greater_0_y_0"), val = tensor(0x0p+0)]; tensor greater_0 = greater(x = var_553, y = greater_0_y_0)[name = tensor("greater_0")]; tensor logical_and_0 = logical_and(x = not_equal_1, y = greater_0)[name = tensor("logical_and_0")]; tensor less_0_y_0 = const()[name = tensor("less_0_y_0"), val = tensor(0x0p+0)]; tensor less_0 = less(x = var_553, y = less_0_y_0)[name = tensor("less_0")]; tensor logical_and_1 = logical_and(x = not_equal_1, y = less_0)[name = tensor("logical_and_1")]; tensor cast_14_dtype_0 = const()[name = tensor("cast_14_dtype_0"), val = tensor("int32")]; tensor cast_14 = cast(dtype = cast_14_dtype_0, x = logical_and_0)[name = tensor("cast_239")]; tensor non_zero_1 = non_zero(x = cast_14)[name = tensor("non_zero_1")]; tensor expand_dims_2 = const()[name = tensor("expand_dims_2"), val = tensor([0x1.fffffep+127])]; tensor shape_3 = shape(x = non_zero_1)[name = tensor("shape_3")]; tensor slice_by_index_5_begin_0 = const()[name = tensor("slice_by_index_5_begin_0"), val = tensor([0])]; tensor slice_by_index_5_end_0 = const()[name = tensor("slice_by_index_5_end_0"), val = tensor([0])]; tensor slice_by_index_5_squeeze_mask_0 = const()[name = tensor("slice_by_index_5_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_5 = slice_by_index(begin = slice_by_index_5_begin_0, end = slice_by_index_5_end_0, squeeze_mask = slice_by_index_5_squeeze_mask_0, x = shape_3)[name = tensor("slice_by_index_5")]; tensor expand_dims_3_axes_0 = const()[name = tensor("expand_dims_3_axes_0"), val = tensor([0])]; tensor expand_dims_3 = expand_dims(axes = expand_dims_3_axes_0, x = slice_by_index_5)[name = tensor("expand_dims_3")]; tensor tile_1 = tile(reps = expand_dims_3, x = expand_dims_2)[name = tensor("tile_1")]; tensor scatter_nd_1_mode_0 = const()[name = tensor("scatter_nd_1_mode_0"), val = tensor("update")]; tensor scatter_nd_1 = scatter_nd(data = scatter_nd_0, indices = non_zero_1, mode = scatter_nd_1_mode_0, updates = tile_1)[name = tensor("scatter_nd_1")]; tensor cast_15_dtype_0 = const()[name = tensor("cast_15_dtype_0"), val = tensor("int32")]; tensor cast_15 = cast(dtype = cast_15_dtype_0, x = logical_and_1)[name = tensor("cast_238")]; tensor non_zero_2 = non_zero(x = cast_15)[name = tensor("non_zero_2")]; tensor expand_dims_4 = const()[name = tensor("expand_dims_4"), val = tensor([-0x1.fffffep+127])]; tensor shape_4 = shape(x = non_zero_2)[name = tensor("shape_4")]; tensor slice_by_index_6_begin_0 = const()[name = tensor("slice_by_index_6_begin_0"), val = tensor([0])]; tensor slice_by_index_6_end_0 = const()[name = tensor("slice_by_index_6_end_0"), val = tensor([0])]; tensor slice_by_index_6_squeeze_mask_0 = const()[name = tensor("slice_by_index_6_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_6 = slice_by_index(begin = slice_by_index_6_begin_0, end = slice_by_index_6_end_0, squeeze_mask = slice_by_index_6_squeeze_mask_0, x = shape_4)[name = tensor("slice_by_index_6")]; tensor expand_dims_5_axes_0 = const()[name = tensor("expand_dims_5_axes_0"), val = tensor([0])]; tensor expand_dims_5 = expand_dims(axes = expand_dims_5_axes_0, x = slice_by_index_6)[name = tensor("expand_dims_5")]; tensor tile_2 = tile(reps = expand_dims_5, x = expand_dims_4)[name = tensor("tile_2")]; tensor scatter_nd_2_mode_0 = const()[name = tensor("scatter_nd_2_mode_0"), val = tensor("update")]; tensor scatter_nd_2 = scatter_nd(data = scatter_nd_1, indices = non_zero_2, mode = scatter_nd_2_mode_0, updates = tile_2)[name = tensor("scatter_nd_2")]; tensor var_558_promoted = const()[name = tensor("op_558_promoted"), val = tensor(0x1p+0)]; tensor var_560 = sub(x = var_558_promoted, y = padded_context_mask)[name = tensor("op_560")]; tensor var_561 = const()[name = tensor("op_561"), val = tensor(0x1.0624dep-10)]; tensor var_562 = mul(x = var_560, y = var_561)[name = tensor("op_562")]; tensor var_564 = add(x = padded_context_mask, y = var_562)[name = tensor("op_564")]; tensor log_padded_epsilon_0 = const()[name = tensor("log_padded_epsilon_0"), val = tensor(0x1p-149)]; tensor log_padded = log(epsilon = log_padded_epsilon_0, x = var_564)[name = tensor("log_padded")]; tensor var_567 = sub(x = context_feedback_9, y = scatter_nd_2)[name = tensor("op_567")]; tensor var_568 = abs(x = var_567)[name = tensor("op_568")]; tensor var_570 = const()[name = tensor("op_570"), val = tensor(0x1.0624dep-10)]; tensor var_571 = add(x = var_568, y = var_570)[name = tensor("op_571")]; tensor var_572_epsilon_0 = const()[name = tensor("op_572_epsilon_0"), val = tensor(0x1p-149)]; tensor var_572 = log(epsilon = var_572_epsilon_0, x = var_571)[name = tensor("op_572")]; tensor var_573_promoted = const()[name = tensor("op_573_promoted"), val = tensor(0x1p+1)]; tensor var_574 = mul(x = var_572, y = var_573_promoted)[name = tensor("op_574")]; tensor x_1 = add(x = var_574, y = log_padded)[name = tensor("x_1")]; tensor reduce_max_0_axes_0 = const()[name = tensor("reduce_max_0_axes_0"), val = tensor([0])]; tensor reduce_max_0_keep_dims_0 = const()[name = tensor("reduce_max_0_keep_dims_0"), val = tensor(false)]; tensor reduce_max_0 = reduce_max(axes = reduce_max_0_axes_0, keep_dims = reduce_max_0_keep_dims_0, x = x_1)[name = tensor("reduce_max_0")]; tensor var_582_axes_0 = const()[name = tensor("op_582_axes_0"), val = tensor([0])]; tensor var_582 = expand_dims(axes = var_582_axes_0, x = reduce_max_0)[name = tensor("op_582")]; tensor var_584 = sub(x = x_1, y = var_582)[name = tensor("op_584")]; tensor var_585 = exp(x = var_584)[name = tensor("op_585")]; tensor var_590_axes_0 = const()[name = tensor("op_590_axes_0"), val = tensor([0])]; tensor var_590_keep_dims_0 = const()[name = tensor("op_590_keep_dims_0"), val = tensor(false)]; tensor var_590 = reduce_sum(axes = var_590_axes_0, keep_dims = var_590_keep_dims_0, x = var_585)[name = tensor("op_590")]; tensor var_591_epsilon_0 = const()[name = tensor("op_591_epsilon_0"), val = tensor(0x1p-149)]; tensor var_591 = log(epsilon = var_591_epsilon_0, x = var_590)[name = tensor("op_591")]; tensor var_593 = add(x = reduce_max_0, y = var_591)[name = tensor("op_593")]; tensor var_597_epsilon_0 = const()[name = tensor("op_597_epsilon_0"), val = tensor(0x1p-149)]; tensor var_597 = log(epsilon = var_597_epsilon_0, x = var_549)[name = tensor("op_597")]; tensor var_599 = sub(x = var_593, y = var_597)[name = tensor("op_599")]; tensor var_600 = const()[name = tensor("op_600"), val = tensor(0x1p-1)]; tensor log_adjusted_std = mul(x = var_599, y = var_600)[name = tensor("log_adjusted_std")]; tensor var_602 = exp(x = log_adjusted_std)[name = tensor("op_602")]; tensor not_equal_2 = not_equal(x = var_602, y = var_602)[name = tensor("not_equal_2")]; tensor cast_16_dtype_0 = const()[name = tensor("cast_16_dtype_0"), val = tensor("int32")]; tensor cast_16 = cast(dtype = cast_16_dtype_0, x = not_equal_2)[name = tensor("cast_237")]; tensor non_zero_3 = non_zero(x = cast_16)[name = tensor("non_zero_3")]; tensor expand_dims_6 = const()[name = tensor("expand_dims_6"), val = tensor([0x1p+0])]; tensor shape_5 = shape(x = non_zero_3)[name = tensor("shape_5")]; tensor slice_by_index_7_begin_0 = const()[name = tensor("slice_by_index_7_begin_0"), val = tensor([0])]; tensor slice_by_index_7_end_0 = const()[name = tensor("slice_by_index_7_end_0"), val = tensor([0])]; tensor slice_by_index_7_squeeze_mask_0 = const()[name = tensor("slice_by_index_7_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_7 = slice_by_index(begin = slice_by_index_7_begin_0, end = slice_by_index_7_end_0, squeeze_mask = slice_by_index_7_squeeze_mask_0, x = shape_5)[name = tensor("slice_by_index_7")]; tensor expand_dims_7_axes_0 = const()[name = tensor("expand_dims_7_axes_0"), val = tensor([0])]; tensor expand_dims_7 = expand_dims(axes = expand_dims_7_axes_0, x = slice_by_index_7)[name = tensor("expand_dims_7")]; tensor tile_3 = tile(reps = expand_dims_7, x = expand_dims_6)[name = tensor("tile_3")]; tensor scatter_nd_3_mode_0 = const()[name = tensor("scatter_nd_3_mode_0"), val = tensor("update")]; tensor scatter_nd_3 = scatter_nd(data = var_602, indices = non_zero_3, mode = scatter_nd_3_mode_0, updates = tile_3)[name = tensor("scatter_nd_3")]; tensor mul_1_y_0 = const()[name = tensor("mul_1_y_0"), val = tensor(0x0p+0)]; tensor mul_1 = mul(x = var_602, y = mul_1_y_0)[name = tensor("mul_1")]; tensor not_equal_3 = not_equal(x = mul_1, y = mul_1)[name = tensor("not_equal_3")]; tensor greater_1_y_0 = const()[name = tensor("greater_1_y_0"), val = tensor(0x0p+0)]; tensor greater_1 = greater(x = var_602, y = greater_1_y_0)[name = tensor("greater_1")]; tensor logical_and_2 = logical_and(x = not_equal_3, y = greater_1)[name = tensor("logical_and_2")]; tensor less_1_y_0 = const()[name = tensor("less_1_y_0"), val = tensor(0x0p+0)]; tensor less_1 = less(x = var_602, y = less_1_y_0)[name = tensor("less_1")]; tensor logical_and_3 = logical_and(x = not_equal_3, y = less_1)[name = tensor("logical_and_3")]; tensor cast_17_dtype_0 = const()[name = tensor("cast_17_dtype_0"), val = tensor("int32")]; tensor cast_17 = cast(dtype = cast_17_dtype_0, x = logical_and_2)[name = tensor("cast_236")]; tensor non_zero_4 = non_zero(x = cast_17)[name = tensor("non_zero_4")]; tensor expand_dims_8 = const()[name = tensor("expand_dims_8"), val = tensor([0x1.fffffep+127])]; tensor shape_6 = shape(x = non_zero_4)[name = tensor("shape_6")]; tensor slice_by_index_8_begin_0 = const()[name = tensor("slice_by_index_8_begin_0"), val = tensor([0])]; tensor slice_by_index_8_end_0 = const()[name = tensor("slice_by_index_8_end_0"), val = tensor([0])]; tensor slice_by_index_8_squeeze_mask_0 = const()[name = tensor("slice_by_index_8_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_8 = slice_by_index(begin = slice_by_index_8_begin_0, end = slice_by_index_8_end_0, squeeze_mask = slice_by_index_8_squeeze_mask_0, x = shape_6)[name = tensor("slice_by_index_8")]; tensor expand_dims_9_axes_0 = const()[name = tensor("expand_dims_9_axes_0"), val = tensor([0])]; tensor expand_dims_9 = expand_dims(axes = expand_dims_9_axes_0, x = slice_by_index_8)[name = tensor("expand_dims_9")]; tensor tile_4 = tile(reps = expand_dims_9, x = expand_dims_8)[name = tensor("tile_4")]; tensor scatter_nd_4_mode_0 = const()[name = tensor("scatter_nd_4_mode_0"), val = tensor("update")]; tensor scatter_nd_4 = scatter_nd(data = scatter_nd_3, indices = non_zero_4, mode = scatter_nd_4_mode_0, updates = tile_4)[name = tensor("scatter_nd_4")]; tensor cast_18_dtype_0 = const()[name = tensor("cast_18_dtype_0"), val = tensor("int32")]; tensor cast_18 = cast(dtype = cast_18_dtype_0, x = logical_and_3)[name = tensor("cast_235")]; tensor non_zero_5 = non_zero(x = cast_18)[name = tensor("non_zero_5")]; tensor expand_dims_10 = const()[name = tensor("expand_dims_10"), val = tensor([-0x1.fffffep+127])]; tensor shape_7 = shape(x = non_zero_5)[name = tensor("shape_7")]; tensor slice_by_index_9_begin_0 = const()[name = tensor("slice_by_index_9_begin_0"), val = tensor([0])]; tensor slice_by_index_9_end_0 = const()[name = tensor("slice_by_index_9_end_0"), val = tensor([0])]; tensor slice_by_index_9_squeeze_mask_0 = const()[name = tensor("slice_by_index_9_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_9 = slice_by_index(begin = slice_by_index_9_begin_0, end = slice_by_index_9_end_0, squeeze_mask = slice_by_index_9_squeeze_mask_0, x = shape_7)[name = tensor("slice_by_index_9")]; tensor expand_dims_11_axes_0 = const()[name = tensor("expand_dims_11_axes_0"), val = tensor([0])]; tensor expand_dims_11 = expand_dims(axes = expand_dims_11_axes_0, x = slice_by_index_9)[name = tensor("expand_dims_11")]; tensor tile_5 = tile(reps = expand_dims_11, x = expand_dims_10)[name = tensor("tile_5")]; tensor scatter_nd_5_mode_0 = const()[name = tensor("scatter_nd_5_mode_0"), val = tensor("update")]; tensor scatter_nd_5 = scatter_nd(data = scatter_nd_4, indices = non_zero_5, mode = scatter_nd_5_mode_0, updates = tile_5)[name = tensor("scatter_nd_5")]; tensor var_608 = const()[name = tensor("op_608"), val = tensor(0x1.0c6f7ap-20)]; tensor context_sigma = add(x = scatter_nd_5, y = var_608)[name = tensor("context_sigma")]; tensor var_610 = const()[name = tensor("op_610"), val = tensor(-0x1.19999ap+0)]; tensor var_611 = greater(x = device_context, y = var_610)[name = tensor("op_611")]; tensor var_611_promoted_dtype_0 = const()[name = tensor("op_611_promoted_dtype_0"), val = tensor("fp32")]; tensor var_615 = const()[name = tensor("op_615"), val = tensor(-0x1.ccccccp-1)]; tensor var_616 = less(x = device_context, y = var_615)[name = tensor("op_616")]; tensor var_616_promoted_dtype_0 = const()[name = tensor("op_616_promoted_dtype_0"), val = tensor("fp32")]; tensor var_616_promoted = cast(dtype = var_616_promoted_dtype_0, x = var_616)[name = tensor("cast_233")]; tensor var_611_promoted = cast(dtype = var_611_promoted_dtype_0, x = var_611)[name = tensor("cast_234")]; tensor is_padding_1 = mul(x = var_611_promoted, y = var_616_promoted)[name = tensor("is_padding_1")]; tensor var_621 = const()[name = tensor("op_621"), val = tensor(0x1p+0)]; tensor is_not_padding_1 = sub(x = var_621, y = is_padding_1)[name = tensor("is_not_padding_1")]; tensor var_624 = const()[name = tensor("op_624"), val = tensor(-0x1.e848p+19)]; tensor var_625 = greater(x = context_sigma, y = var_624)[name = tensor("op_625")]; tensor var_625_promoted_dtype_0 = const()[name = tensor("op_625_promoted_dtype_0"), val = tensor("fp32")]; tensor var_625_promoted = cast(dtype = var_625_promoted_dtype_0, x = var_625)[name = tensor("cast_232")]; tensor var_626 = mul(x = is_padding_1, y = var_625_promoted)[name = tensor("op_626")]; tensor var_627 = mul(x = is_not_padding_1, y = context_sigma)[name = tensor("op_627")]; tensor padded_sigma_1 = add(x = var_626, y = var_627)[name = tensor("padded_sigma_1")]; tensor var_630 = mul(x = scatter_nd_2, y = is_not_padding_1)[name = tensor("op_630")]; tensor var_632 = sub(x = device_context, y = var_630)[name = tensor("op_632")]; tensor var_633 = real_div(x = var_632, y = padded_sigma_1)[name = tensor("op_633")]; tensor var_634 = mul(x = is_not_padding_1, y = var_633)[name = tensor("op_634")]; tensor var_635 = mul(x = is_padding_1, y = device_context)[name = tensor("op_635")]; tensor context = add(x = var_634, y = var_635)[name = tensor("context")]; tensor var_640_begin_0 = const()[name = tensor("op_640_begin_0"), val = tensor([0, 0])]; tensor var_640_end_0 = const()[name = tensor("op_640_end_0"), val = tensor([1, 15])]; tensor var_640_end_mask_0 = const()[name = tensor("op_640_end_mask_0"), val = tensor([false, true])]; tensor var_640_squeeze_mask_0 = const()[name = tensor("op_640_squeeze_mask_0"), val = tensor([true, false])]; tensor var_640 = slice_by_index(begin = var_640_begin_0, end = var_640_end_0, end_mask = var_640_end_mask_0, squeeze_mask = var_640_squeeze_mask_0, x = x_3)[name = tensor("op_640")]; tensor var_641 = const()[name = tensor("op_641"), val = tensor(-0x1.e848p+19)]; tensor var_642 = greater(x = var_640, y = var_641)[name = tensor("op_642")]; tensor var_642_promoted_dtype_0 = const()[name = tensor("op_642_promoted_dtype_0"), val = tensor("int32")]; tensor var_645 = const()[name = tensor("op_645"), val = tensor(-0x1.e848p+19)]; tensor ones_1_promoted_dtype_0 = const()[name = tensor("ones_1_promoted_dtype_0"), val = tensor("fp32")]; tensor var_642_to_fp32 = cast(dtype = ones_1_promoted_dtype_0, x = var_642)[name = tensor("cast_230")]; tensor small_1 = mul(x = var_642_to_fp32, y = var_645)[name = tensor("small_1")]; tensor var_647 = const()[name = tensor("op_647"), val = tensor(0x1.e848p+19)]; tensor big_1 = mul(x = var_642_to_fp32, y = var_647)[name = tensor("big_1")]; tensor var_649 = const()[name = tensor("op_649"), val = tensor(0)]; tensor var_642_promoted = cast(dtype = var_642_promoted_dtype_0, x = var_642)[name = tensor("cast_231")]; tensor zeros_1 = mul(x = var_642_promoted, y = var_649)[name = tensor("zeros_1")]; tensor var_652_axes_0 = const()[name = tensor("op_652_axes_0"), val = tensor([0])]; tensor var_652 = expand_dims(axes = var_652_axes_0, x = small_1)[name = tensor("op_652")]; tensor var_654_axes_0 = const()[name = tensor("op_654_axes_0"), val = tensor([0])]; tensor var_654 = expand_dims(axes = var_654_axes_0, x = big_1)[name = tensor("op_654")]; tensor var_656 = const()[name = tensor("op_656"), val = tensor(0)]; tensor x_padded_1_interleave_0 = const()[name = tensor("x_padded_1_interleave_0"), val = tensor(false)]; tensor x_padded_1 = concat(axis = var_656, interleave = x_padded_1_interleave_0, values = (var_652, x_3, var_654))[name = tensor("x_padded_1")]; tensor var_658 = const()[name = tensor("op_658"), val = tensor(0)]; tensor logical_not_0 = const()[name = tensor("logical_not_0"), val = tensor(true)]; tensor i_1 = argsort(ascending = logical_not_0, axis = var_658, x = x_padded_1)[name = tensor("i_1")]; tensor by_x_1 = gather_along_axis(axis = var_658, indices = i_1, x = x_padded_1)[name = tensor("by_x_1")]; tensor var_666_begin_0 = const()[name = tensor("op_666_begin_0"), val = tensor([1, 0])]; tensor var_666_end_0 = const()[name = tensor("op_666_end_0"), val = tensor([51, 15])]; tensor var_666_end_mask_0 = const()[name = tensor("op_666_end_mask_0"), val = tensor([false, true])]; tensor var_666 = slice_by_index(begin = var_666_begin_0, end = var_666_end_0, end_mask = var_666_end_mask_0, x = by_x_1)[name = tensor("op_666")]; tensor var_671_begin_0 = const()[name = tensor("op_671_begin_0"), val = tensor([0, 0])]; tensor var_671_end_0 = const()[name = tensor("op_671_end_0"), val = tensor([50, 15])]; tensor var_671_end_mask_0 = const()[name = tensor("op_671_end_mask_0"), val = tensor([false, true])]; tensor var_671 = slice_by_index(begin = var_671_begin_0, end = var_671_end_0, end_mask = var_671_end_mask_0, x = by_x_1)[name = tensor("op_671")]; tensor var_673 = sub(x = var_666, y = var_671)[name = tensor("op_673")]; tensor var_674_promoted = const()[name = tensor("op_674_promoted"), val = tensor(0x0p+0)]; tensor var_675 = greater(x = var_673, y = var_674_promoted)[name = tensor("op_675")]; tensor var_675_promoted_dtype_0 = const()[name = tensor("op_675_promoted_dtype_0"), val = tensor("int32")]; tensor var_679_axes_0 = const()[name = tensor("op_679_axes_0"), val = tensor([0])]; tensor var_679 = expand_dims(axes = var_679_axes_0, x = zeros_1)[name = tensor("op_679")]; tensor var_683 = const()[name = tensor("op_683"), val = tensor(0)]; tensor mask_5_interleave_0 = const()[name = tensor("mask_5_interleave_0"), val = tensor(false)]; tensor var_675_promoted = cast(dtype = var_675_promoted_dtype_0, x = var_675)[name = tensor("cast_229")]; tensor mask_5 = concat(axis = var_683, interleave = mask_5_interleave_0, values = (var_679, var_675_promoted, var_679))[name = tensor("mask_5")]; tensor mask_5_promoted_dtype_0 = const()[name = tensor("mask_5_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_5_promoted = cast(dtype = mask_5_promoted_dtype_0, x = mask_5)[name = tensor("cast_228")]; tensor var_685 = mul(x = by_x_1, y = mask_5_promoted)[name = tensor("op_685")]; tensor var_686 = const()[name = tensor("op_686"), val = tensor(0)]; tensor logical_not_1 = const()[name = tensor("logical_not_1"), val = tensor(true)]; tensor var_688 = argsort(ascending = logical_not_1, axis = var_686, x = i_1)[name = tensor("op_688")]; tensor var_689 = const()[name = tensor("op_689"), val = tensor(0)]; tensor unique_1 = gather_along_axis(axis = var_689, indices = var_688, x = var_685)[name = tensor("unique_1")]; tensor unique_candidates_begin_0 = const()[name = tensor("unique_candidates_begin_0"), val = tensor([1, 0])]; tensor unique_candidates_end_0 = const()[name = tensor("unique_candidates_end_0"), val = tensor([51, 15])]; tensor unique_candidates_end_mask_0 = const()[name = tensor("unique_candidates_end_mask_0"), val = tensor([false, true])]; tensor unique_candidates = slice_by_index(begin = unique_candidates_begin_0, end = unique_candidates_end_0, end_mask = unique_candidates_end_mask_0, x = unique_1)[name = tensor("unique_candidates")]; tensor var_710_begin_0 = const()[name = tensor("op_710_begin_0"), val = tensor([1, 0])]; tensor var_710_end_0 = const()[name = tensor("op_710_end_0"), val = tensor([2, 1000])]; tensor var_710_end_mask_0 = const()[name = tensor("op_710_end_mask_0"), val = tensor([false, true])]; tensor var_710 = slice_by_index(begin = var_710_begin_0, end = var_710_end_0, end_mask = var_710_end_mask_0, x = candidate_interactions_transpose)[name = tensor("op_710")]; tensor alignment_feedback_1 = squeeze(x = var_710)[name = tensor("alignment_feedback_1")]; tensor var_719_begin_0 = const()[name = tensor("op_719_begin_0"), val = tensor([2, 0])]; tensor var_719_end_0 = const()[name = tensor("op_719_end_0"), val = tensor([3, 1000])]; tensor var_719_end_mask_0 = const()[name = tensor("op_719_end_mask_0"), val = tensor([false, true])]; tensor var_719 = slice_by_index(begin = var_719_begin_0, end = var_719_end_0, end_mask = var_719_end_mask_0, x = candidate_interactions_transpose)[name = tensor("op_719")]; tensor hist_parameter_names = squeeze(x = var_719)[name = tensor("hist_parameter_names")]; tensor var_742 = const()[name = tensor("op_742"), val = tensor(-0x1.8p+0)]; tensor var_743 = greater(x = context_feedback_5, y = var_742)[name = tensor("op_743")]; tensor var_743_promoted_dtype_0 = const()[name = tensor("op_743_promoted_dtype_0"), val = tensor("fp32")]; tensor var_747 = const()[name = tensor("op_747"), val = tensor(-0x1p-1)]; tensor var_748 = less(x = context_feedback_5, y = var_747)[name = tensor("op_748")]; tensor var_748_promoted_dtype_0 = const()[name = tensor("op_748_promoted_dtype_0"), val = tensor("fp32")]; tensor var_748_promoted = cast(dtype = var_748_promoted_dtype_0, x = var_748)[name = tensor("cast_226")]; tensor var_743_promoted = cast(dtype = var_743_promoted_dtype_0, x = var_743)[name = tensor("cast_227")]; tensor var_752 = mul(x = var_743_promoted, y = var_748_promoted)[name = tensor("op_752")]; tensor var_753_promoted = const()[name = tensor("op_753_promoted"), val = tensor(0x1p+0)]; tensor not_padded_feedback_1 = sub(x = var_753_promoted, y = var_752)[name = tensor("not_padded_feedback_1")]; tensor var_756 = const()[name = tensor("op_756"), val = tensor(-0x1.19999ap+0)]; tensor var_757 = greater(x = context_feedback_5, y = var_756)[name = tensor("op_757")]; tensor var_757_promoted_dtype_0 = const()[name = tensor("op_757_promoted_dtype_0"), val = tensor("fp32")]; tensor var_761 = const()[name = tensor("op_761"), val = tensor(-0x1.ccccccp-1)]; tensor var_762 = less(x = context_feedback_5, y = var_761)[name = tensor("op_762")]; tensor var_762_promoted_dtype_0 = const()[name = tensor("op_762_promoted_dtype_0"), val = tensor("fp32")]; tensor var_762_promoted = cast(dtype = var_762_promoted_dtype_0, x = var_762)[name = tensor("cast_224")]; tensor var_757_promoted = cast(dtype = var_757_promoted_dtype_0, x = var_757)[name = tensor("cast_225")]; tensor is_padding_3 = mul(x = var_757_promoted, y = var_762_promoted)[name = tensor("is_padding_3")]; tensor var_767 = const()[name = tensor("op_767"), val = tensor(0x1p+0)]; tensor is_not_padding_3 = sub(x = var_767, y = is_padding_3)[name = tensor("is_not_padding_3")]; tensor var_772 = mul(x = is_padding_3, y = var_625_promoted)[name = tensor("op_772")]; tensor var_773 = mul(x = is_not_padding_3, y = context_sigma)[name = tensor("op_773")]; tensor padded_sigma_3 = add(x = var_772, y = var_773)[name = tensor("padded_sigma_3")]; tensor var_776 = mul(x = scatter_nd_2, y = is_not_padding_3)[name = tensor("op_776")]; tensor var_778 = sub(x = context_feedback_5, y = var_776)[name = tensor("op_778")]; tensor var_779 = real_div(x = var_778, y = padded_sigma_3)[name = tensor("op_779")]; tensor var_780 = mul(x = is_not_padding_3, y = var_779)[name = tensor("op_780")]; tensor var_781 = mul(x = is_padding_3, y = context_feedback_5)[name = tensor("op_781")]; tensor context_feedback_17 = add(x = var_780, y = var_781)[name = tensor("context_feedback_17")]; tensor var_784 = const()[name = tensor("op_784"), val = tensor(-0x1p-1)]; tensor var_785 = greater(x = alignment_feedback_1, y = var_784)[name = tensor("op_785")]; tensor var_785_promoted_dtype_0 = const()[name = tensor("op_785_promoted_dtype_0"), val = tensor("fp32")]; tensor var_790 = const()[name = tensor("op_790"), val = tensor(0x1p-1)]; tensor var_791 = sub(x = alignment_feedback_1, y = var_790)[name = tensor("op_791")]; tensor var_785_promoted = cast(dtype = var_785_promoted_dtype_0, x = var_785)[name = tensor("cast_223")]; tensor var_792 = mul(x = var_785_promoted, y = var_791)[name = tensor("op_792")]; tensor var_793_promoted = const()[name = tensor("op_793_promoted"), val = tensor(0x1p+1)]; tensor alignment_feedback = mul(x = var_792, y = var_793_promoted)[name = tensor("alignment_feedback")]; tensor pn_which_column_1_axes_0 = const()[name = tensor("pn_which_column_1_axes_0"), val = tensor([0])]; tensor pn_which_column_1 = expand_dims(axes = pn_which_column_1_axes_0, x = parameterName)[name = tensor("pn_which_column_1")]; tensor var_798_axes_0 = const()[name = tensor("op_798_axes_0"), val = tensor([1])]; tensor var_798 = expand_dims(axes = var_798_axes_0, x = hist_parameter_names)[name = tensor("op_798")]; tensor var_800 = sub(x = pn_which_column_1, y = var_798)[name = tensor("op_800")]; tensor var_801 = abs(x = var_800)[name = tensor("op_801")]; tensor var_802 = const()[name = tensor("op_802"), val = tensor(0x1.a36e2ep-15)]; tensor var_803 = less(x = var_801, y = var_802)[name = tensor("op_803")]; tensor var_803_promoted_dtype_0 = const()[name = tensor("op_803_promoted_dtype_0"), val = tensor("fp32")]; tensor var_808_axes_0 = const()[name = tensor("op_808_axes_0"), val = tensor([0])]; tensor var_808 = expand_dims(axes = var_808_axes_0, x = x_3)[name = tensor("op_808")]; tensor var_810_axes_0 = const()[name = tensor("op_810_axes_0"), val = tensor([1])]; tensor var_810 = expand_dims(axes = var_810_axes_0, x = var_464)[name = tensor("op_810")]; tensor var_812_axes_0 = const()[name = tensor("op_812_axes_0"), val = tensor([1])]; tensor var_812 = expand_dims(axes = var_812_axes_0, x = var_810)[name = tensor("op_812")]; tensor var_814 = sub(x = var_808, y = var_812)[name = tensor("op_814")]; tensor var_815 = abs(x = var_814)[name = tensor("op_815")]; tensor var_816 = const()[name = tensor("op_816"), val = tensor(0x1.a36e2ep-15)]; tensor var_817 = less(x = var_815, y = var_816)[name = tensor("op_817")]; tensor var_817_promoted_dtype_0 = const()[name = tensor("op_817_promoted_dtype_0"), val = tensor("fp32")]; tensor reduce_max_1_axes_0 = const()[name = tensor("reduce_max_1_axes_0"), val = tensor([1])]; tensor reduce_max_1_keep_dims_0 = const()[name = tensor("reduce_max_1_keep_dims_0"), val = tensor(false)]; tensor var_817_promoted = cast(dtype = var_817_promoted_dtype_0, x = var_817)[name = tensor("cast_221")]; tensor reduce_max_1 = reduce_max(axes = reduce_max_1_axes_0, keep_dims = reduce_max_1_keep_dims_0, x = var_817_promoted)[name = tensor("reduce_max_1")]; tensor var_803_promoted = cast(dtype = var_803_promoted_dtype_0, x = var_803)[name = tensor("cast_222")]; tensor name_relevant_1 = mul(x = var_803_promoted, y = reduce_max_1)[name = tensor("name_relevant_1")]; tensor var_826_promoted = const()[name = tensor("op_826_promoted"), val = tensor(0x1p+0)]; tensor var_828 = sub(x = var_826_promoted, y = name_relevant_1)[name = tensor("op_828")]; tensor var_829 = const()[name = tensor("op_829"), val = tensor(0x1.99999ap-4)]; tensor var_830 = mul(x = var_828, y = var_829)[name = tensor("op_830")]; tensor name_relevant_3 = add(x = var_830, y = name_relevant_1)[name = tensor("name_relevant_3")]; tensor reduce_max_2_axes_0 = const()[name = tensor("reduce_max_2_axes_0"), val = tensor([1])]; tensor reduce_max_2_keep_dims_0 = const()[name = tensor("reduce_max_2_keep_dims_0"), val = tensor(false)]; tensor reduce_max_2 = reduce_max(axes = reduce_max_2_axes_0, keep_dims = reduce_max_2_keep_dims_0, x = name_relevant_3)[name = tensor("reduce_max_2")]; tensor time_context_1_begin_0 = const()[name = tensor("time_context_1_begin_0"), val = tensor([0])]; tensor time_context_1_end_0 = const()[name = tensor("time_context_1_end_0"), val = tensor([1])]; tensor time_context_1_end_mask_0 = const()[name = tensor("time_context_1_end_mask_0"), val = tensor([false])]; tensor time_context_1_squeeze_mask_0 = const()[name = tensor("time_context_1_squeeze_mask_0"), val = tensor([true])]; tensor time_context_1 = slice_by_index(begin = time_context_1_begin_0, end = time_context_1_end_0, end_mask = time_context_1_end_mask_0, squeeze_mask = time_context_1_squeeze_mask_0, x = context)[name = tensor("time_context_1")]; tensor location_context_1_begin_0 = const()[name = tensor("location_context_1_begin_0"), val = tensor([1])]; tensor location_context_1_end_0 = const()[name = tensor("location_context_1_end_0"), val = tensor([4])]; tensor location_context_1_end_mask_0 = const()[name = tensor("location_context_1_end_mask_0"), val = tensor([false])]; tensor location_context_1 = slice_by_index(begin = location_context_1_begin_0, end = location_context_1_end_0, end_mask = location_context_1_end_mask_0, x = context)[name = tensor("location_context_1")]; tensor freq_context_1_begin_0 = const()[name = tensor("freq_context_1_begin_0"), val = tensor([4])]; tensor freq_context_1_end_0 = const()[name = tensor("freq_context_1_end_0"), val = tensor([6])]; tensor freq_context_1_end_mask_0 = const()[name = tensor("freq_context_1_end_mask_0"), val = tensor([true])]; tensor freq_context_1 = slice_by_index(begin = freq_context_1_begin_0, end = freq_context_1_end_0, end_mask = freq_context_1_end_mask_0, x = context)[name = tensor("freq_context_1")]; tensor var_852_perm_0 = const()[name = tensor("op_852_perm_0"), val = tensor([1, 0])]; tensor time_context_feedback_1_begin_0 = const()[name = tensor("time_context_feedback_1_begin_0"), val = tensor([0, 0])]; tensor time_context_feedback_1_end_0 = const()[name = tensor("time_context_feedback_1_end_0"), val = tensor([1, 1000])]; tensor time_context_feedback_1_end_mask_0 = const()[name = tensor("time_context_feedback_1_end_mask_0"), val = tensor([false, true])]; tensor time_context_feedback_1_squeeze_mask_0 = const()[name = tensor("time_context_feedback_1_squeeze_mask_0"), val = tensor([true, false])]; tensor var_852 = transpose(perm = var_852_perm_0, x = context_feedback_17)[name = tensor("transpose_40")]; tensor time_context_feedback_1 = slice_by_index(begin = time_context_feedback_1_begin_0, end = time_context_feedback_1_end_0, end_mask = time_context_feedback_1_end_mask_0, squeeze_mask = time_context_feedback_1_squeeze_mask_0, x = var_852)[name = tensor("time_context_feedback_1")]; tensor var_858_perm_0 = const()[name = tensor("op_858_perm_0"), val = tensor([1, 0])]; tensor not_padded_time_1_begin_0 = const()[name = tensor("not_padded_time_1_begin_0"), val = tensor([0, 0])]; tensor not_padded_time_1_end_0 = const()[name = tensor("not_padded_time_1_end_0"), val = tensor([1, 1000])]; tensor not_padded_time_1_end_mask_0 = const()[name = tensor("not_padded_time_1_end_mask_0"), val = tensor([false, true])]; tensor not_padded_time_1_squeeze_mask_0 = const()[name = tensor("not_padded_time_1_squeeze_mask_0"), val = tensor([true, false])]; tensor var_858 = transpose(perm = var_858_perm_0, x = not_padded_feedback_1)[name = tensor("transpose_39")]; tensor not_padded_time_1 = slice_by_index(begin = not_padded_time_1_begin_0, end = not_padded_time_1_end_0, end_mask = not_padded_time_1_end_mask_0, squeeze_mask = not_padded_time_1_squeeze_mask_0, x = var_858)[name = tensor("not_padded_time_1")]; tensor var_869_begin_0 = const()[name = tensor("op_869_begin_0"), val = tensor([1, 0])]; tensor var_869_end_0 = const()[name = tensor("op_869_end_0"), val = tensor([4, 1000])]; tensor var_869_end_mask_0 = const()[name = tensor("op_869_end_mask_0"), val = tensor([false, true])]; tensor var_869 = slice_by_index(begin = var_869_begin_0, end = var_869_end_0, end_mask = var_869_end_mask_0, x = var_852)[name = tensor("op_869")]; tensor location_context_feedback_1_perm_0 = const()[name = tensor("location_context_feedback_1_perm_0"), val = tensor([1, 0])]; tensor var_880_begin_0 = const()[name = tensor("op_880_begin_0"), val = tensor([1, 0])]; tensor var_880_end_0 = const()[name = tensor("op_880_end_0"), val = tensor([4, 1000])]; tensor var_880_end_mask_0 = const()[name = tensor("op_880_end_mask_0"), val = tensor([false, true])]; tensor var_880 = slice_by_index(begin = var_880_begin_0, end = var_880_end_0, end_mask = var_880_end_mask_0, x = var_858)[name = tensor("op_880")]; tensor not_padded_location_1_perm_0 = const()[name = tensor("not_padded_location_1_perm_0"), val = tensor([1, 0])]; tensor var_891_begin_0 = const()[name = tensor("op_891_begin_0"), val = tensor([4, 0])]; tensor var_891_end_0 = const()[name = tensor("op_891_end_0"), val = tensor([6, 1000])]; tensor var_891_end_mask_0 = const()[name = tensor("op_891_end_mask_0"), val = tensor([true, true])]; tensor var_891 = slice_by_index(begin = var_891_begin_0, end = var_891_end_0, end_mask = var_891_end_mask_0, x = var_852)[name = tensor("op_891")]; tensor freq_context_feedback_1_perm_0 = const()[name = tensor("freq_context_feedback_1_perm_0"), val = tensor([1, 0])]; tensor var_902_begin_0 = const()[name = tensor("op_902_begin_0"), val = tensor([4, 0])]; tensor var_902_end_0 = const()[name = tensor("op_902_end_0"), val = tensor([6, 1000])]; tensor var_902_end_mask_0 = const()[name = tensor("op_902_end_mask_0"), val = tensor([true, true])]; tensor var_902 = slice_by_index(begin = var_902_begin_0, end = var_902_end_0, end_mask = var_902_end_mask_0, x = var_858)[name = tensor("op_902")]; tensor not_padded_freq_1_perm_0 = const()[name = tensor("not_padded_freq_1_perm_0"), val = tensor([1, 0])]; tensor var_907 = sub(x = time_context_feedback_1, y = time_context_1)[name = tensor("op_907")]; tensor var_908 = abs(x = var_907)[name = tensor("op_908")]; tensor similarity_time_1 = mul(x = var_908, y = not_padded_time_1)[name = tensor("similarity_time_1")]; tensor freq_context_feedback_1 = transpose(perm = freq_context_feedback_1_perm_0, x = var_891)[name = tensor("transpose_36")]; tensor var_911 = sub(x = freq_context_feedback_1, y = freq_context_1)[name = tensor("op_911")]; tensor not_padded_freq_1 = transpose(perm = not_padded_freq_1_perm_0, x = var_902)[name = tensor("transpose_35")]; tensor input_1 = mul(x = var_911, y = not_padded_freq_1)[name = tensor("input_1")]; tensor var_915 = const()[name = tensor("op_915"), val = tensor([1])]; tensor var_916 = const()[name = tensor("op_916"), val = tensor(false)]; tensor similarity_freq_1 = reduce_l2_norm(axes = var_915, keep_dims = var_916, x = input_1)[name = tensor("similarity_freq_1")]; tensor location_context_feedback_1 = transpose(perm = location_context_feedback_1_perm_0, x = var_869)[name = tensor("transpose_38")]; tensor var_920 = sub(x = location_context_feedback_1, y = location_context_1)[name = tensor("op_920")]; tensor not_padded_location_1 = transpose(perm = not_padded_location_1_perm_0, x = var_880)[name = tensor("transpose_37")]; tensor input_3 = mul(x = var_920, y = not_padded_location_1)[name = tensor("input_3")]; tensor var_924 = const()[name = tensor("op_924"), val = tensor([1])]; tensor var_925 = const()[name = tensor("op_925"), val = tensor(false)]; tensor similarity_location_1 = reduce_l2_norm(axes = var_924, keep_dims = var_925, x = input_3)[name = tensor("similarity_location_1")]; tensor var_928 = const()[name = tensor("op_928"), val = tensor(0x1p-1)]; tensor var_929 = equal(x = candidate_interactions, y = var_928)[name = tensor("op_929")]; tensor var_929_promoted_dtype_0 = const()[name = tensor("op_929_promoted_dtype_0"), val = tensor("fp32")]; tensor var_936_axes_0 = const()[name = tensor("op_936_axes_0"), val = tensor([1])]; tensor var_936_keep_dims_0 = const()[name = tensor("op_936_keep_dims_0"), val = tensor(false)]; tensor var_929_promoted = cast(dtype = var_929_promoted_dtype_0, x = var_929)[name = tensor("cast_220")]; tensor var_936 = reduce_sum(axes = var_936_axes_0, keep_dims = var_936_keep_dims_0, x = var_929_promoted)[name = tensor("op_936")]; tensor var_937_promoted = const()[name = tensor("op_937_promoted"), val = tensor(0x1p+0)]; tensor var_939 = sub(x = var_937_promoted, y = var_936)[name = tensor("op_939")]; tensor var_940 = mul(x = not_padded_time_1, y = var_939)[name = tensor("op_940")]; tensor n_time_1_axes_0 = const()[name = tensor("n_time_1_axes_0"), val = tensor([0])]; tensor n_time_1_keep_dims_0 = const()[name = tensor("n_time_1_keep_dims_0"), val = tensor(false)]; tensor n_time_1 = reduce_sum(axes = n_time_1_axes_0, keep_dims = n_time_1_keep_dims_0, x = var_940)[name = tensor("n_time_1")]; tensor var_946 = const()[name = tensor("op_946"), val = tensor(0x1.c1872cp-3)]; tensor var_947 = pow(x = n_time_1, y = var_946)[name = tensor("op_947")]; tensor var_948_epsilon_0 = const()[name = tensor("op_948_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_948 = inverse(epsilon = var_948_epsilon_0, x = var_947)[name = tensor("op_948")]; tensor var_949 = const()[name = tensor("op_949"), val = tensor(0x1.95e9e4p+2)]; tensor bw_time_1 = mul(x = var_948, y = var_949)[name = tensor("bw_time_1")]; tensor var_955_axes_0 = const()[name = tensor("op_955_axes_0"), val = tensor([1])]; tensor var_955_keep_dims_0 = const()[name = tensor("op_955_keep_dims_0"), val = tensor(false)]; tensor var_955 = reduce_sum(axes = var_955_axes_0, keep_dims = var_955_keep_dims_0, x = not_padded_freq_1)[name = tensor("op_955")]; tensor var_956 = const()[name = tensor("op_956"), val = tensor(0x0p+0)]; tensor var_957 = greater(x = var_955, y = var_956)[name = tensor("op_957")]; tensor var_957_promoted_dtype_0 = const()[name = tensor("op_957_promoted_dtype_0"), val = tensor("fp32")]; tensor var_957_promoted = cast(dtype = var_957_promoted_dtype_0, x = var_957)[name = tensor("cast_219")]; tensor var_970 = mul(x = var_957_promoted, y = var_939)[name = tensor("op_970")]; tensor n_freq_1_keep_dims_0 = const()[name = tensor("n_freq_1_keep_dims_0"), val = tensor(false)]; tensor n_freq_1 = reduce_sum(keep_dims = n_freq_1_keep_dims_0, x = var_970)[name = tensor("n_freq_1")]; tensor var_973 = const()[name = tensor("op_973"), val = tensor(0x1p-1)]; tensor var_974 = pow(x = n_freq_1, y = var_973)[name = tensor("op_974")]; tensor var_975_epsilon_0 = const()[name = tensor("op_975_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_975 = inverse(epsilon = var_975_epsilon_0, x = var_974)[name = tensor("op_975")]; tensor var_976 = const()[name = tensor("op_976"), val = tensor(0x1.f80ac8p+2)]; tensor bw_freq_1 = mul(x = var_975, y = var_976)[name = tensor("bw_freq_1")]; tensor var_982_axes_0 = const()[name = tensor("op_982_axes_0"), val = tensor([1])]; tensor var_982_keep_dims_0 = const()[name = tensor("op_982_keep_dims_0"), val = tensor(false)]; tensor var_982 = reduce_sum(axes = var_982_axes_0, keep_dims = var_982_keep_dims_0, x = not_padded_location_1)[name = tensor("op_982")]; tensor var_983 = const()[name = tensor("op_983"), val = tensor(0x0p+0)]; tensor var_984 = greater(x = var_982, y = var_983)[name = tensor("op_984")]; tensor var_984_promoted_dtype_0 = const()[name = tensor("op_984_promoted_dtype_0"), val = tensor("fp32")]; tensor var_984_promoted = cast(dtype = var_984_promoted_dtype_0, x = var_984)[name = tensor("cast_218")]; tensor var_997 = mul(x = var_984_promoted, y = var_939)[name = tensor("op_997")]; tensor n_location_1_keep_dims_0 = const()[name = tensor("n_location_1_keep_dims_0"), val = tensor(false)]; tensor n_location_1 = reduce_sum(keep_dims = n_location_1_keep_dims_0, x = var_997)[name = tensor("n_location_1")]; tensor var_1000 = const()[name = tensor("op_1000"), val = tensor(0x1.e1583ep-2)]; tensor var_1001 = pow(x = n_location_1, y = var_1000)[name = tensor("op_1001")]; tensor var_1002_epsilon_0 = const()[name = tensor("op_1002_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_1002 = inverse(epsilon = var_1002_epsilon_0, x = var_1001)[name = tensor("op_1002")]; tensor var_1003 = const()[name = tensor("op_1003"), val = tensor(0x1.292be6p+4)]; tensor bw_location_1 = mul(x = var_1002, y = var_1003)[name = tensor("bw_location_1")]; tensor var_1006_axes_0 = const()[name = tensor("op_1006_axes_0"), val = tensor([-1])]; tensor var_1006 = expand_dims(axes = var_1006_axes_0, x = var_464)[name = tensor("op_1006")]; tensor var_1008_axes_0 = const()[name = tensor("op_1008_axes_0"), val = tensor([-1])]; tensor var_1008 = expand_dims(axes = var_1008_axes_0, x = var_1006)[name = tensor("op_1008")]; tensor var_1009_promoted = const()[name = tensor("op_1009_promoted"), val = tensor(-0x1.f4p+9)]; tensor var_1010 = greater(x = x_3, y = var_1009_promoted)[name = tensor("op_1010")]; tensor var_1010_promoted_dtype_0 = const()[name = tensor("op_1010_promoted_dtype_0"), val = tensor("int32")]; tensor var_1014_axes_0 = const()[name = tensor("op_1014_axes_0"), val = tensor([0])]; tensor var_1010_promoted = cast(dtype = var_1010_promoted_dtype_0, x = var_1010)[name = tensor("cast_217")]; tensor var_1014 = expand_dims(axes = var_1014_axes_0, x = var_1010_promoted)[name = tensor("op_1014")]; tensor var_1014_promoted_dtype_0 = const()[name = tensor("op_1014_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1014_promoted = cast(dtype = var_1014_promoted_dtype_0, x = var_1014)[name = tensor("cast_216")]; tensor expanded = mul(x = var_1008, y = var_1014_promoted)[name = tensor("expanded")]; tensor pos_align_candidates = mul(x = var_464, y = alignment_feedback)[name = tensor("pos_align_candidates")]; tensor var_1017_promoted = const()[name = tensor("op_1017_promoted"), val = tensor(0x0p+0)]; tensor var_1018 = greater(x = alignment_feedback, y = var_1017_promoted)[name = tensor("op_1018")]; tensor var_1018_promoted_dtype_0 = const()[name = tensor("op_1018_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1018_promoted = cast(dtype = var_1018_promoted_dtype_0, x = var_1018)[name = tensor("cast_215")]; tensor alignment_scaling_1 = mul(x = var_1018_promoted, y = alignment_feedback)[name = tensor("alignment_scaling_1")]; tensor var_1020 = real_div(x = similarity_location_1, y = bw_location_1)[name = tensor("op_1020")]; tensor var_1021_promoted = const()[name = tensor("op_1021_promoted"), val = tensor(0x1p+1)]; tensor var_1022 = pow(x = var_1020, y = var_1021_promoted)[name = tensor("op_1022")]; tensor var_1023_promoted = const()[name = tensor("op_1023_promoted"), val = tensor(-0x1p+0)]; tensor var_1024 = mul(x = var_1022, y = var_1023_promoted)[name = tensor("op_1024")]; tensor location_score_1 = exp(x = var_1024)[name = tensor("location_score_1")]; tensor var_1026 = real_div(x = similarity_time_1, y = bw_time_1)[name = tensor("op_1026")]; tensor var_1027_promoted = const()[name = tensor("op_1027_promoted"), val = tensor(0x1p+1)]; tensor var_1028 = pow(x = var_1026, y = var_1027_promoted)[name = tensor("op_1028")]; tensor var_1029_promoted = const()[name = tensor("op_1029_promoted"), val = tensor(-0x1p+0)]; tensor var_1030 = mul(x = var_1028, y = var_1029_promoted)[name = tensor("op_1030")]; tensor time_score_1 = exp(x = var_1030)[name = tensor("time_score_1")]; tensor var_1032 = real_div(x = similarity_freq_1, y = bw_freq_1)[name = tensor("op_1032")]; tensor var_1033_promoted = const()[name = tensor("op_1033_promoted"), val = tensor(0x1p+1)]; tensor var_1034 = pow(x = var_1032, y = var_1033_promoted)[name = tensor("op_1034")]; tensor var_1035_promoted = const()[name = tensor("op_1035_promoted"), val = tensor(-0x1p+0)]; tensor var_1036 = mul(x = var_1034, y = var_1035_promoted)[name = tensor("op_1036")]; tensor freq_score_1 = exp(x = var_1036)[name = tensor("freq_score_1")]; tensor var_1038 = mul(x = alignment_scaling_1, y = time_score_1)[name = tensor("op_1038")]; tensor var_1039 = mul(x = var_1038, y = freq_score_1)[name = tensor("op_1039")]; tensor var_1040 = mul(x = var_1039, y = location_score_1)[name = tensor("op_1040")]; tensor candidate_psuedo_counts_1 = mul(x = var_1040, y = reduce_max_2)[name = tensor("candidate_psuedo_counts_1")]; tensor var_1043_axes_0 = const()[name = tensor("op_1043_axes_0"), val = tensor([-1])]; tensor var_1043 = expand_dims(axes = var_1043_axes_0, x = pos_align_candidates)[name = tensor("op_1043")]; tensor var_1045_axes_0 = const()[name = tensor("op_1045_axes_0"), val = tensor([-1])]; tensor var_1045 = expand_dims(axes = var_1045_axes_0, x = var_1043)[name = tensor("op_1045")]; tensor var_1052 = mul(x = var_1045, y = var_1014_promoted)[name = tensor("op_1052")]; tensor var_1053_promoted = const()[name = tensor("op_1053_promoted"), val = tensor(0x0p+0)]; tensor mask_7 = greater(x = var_1052, y = var_1053_promoted)[name = tensor("mask_7")]; tensor mask_7_promoted_dtype_0 = const()[name = tensor("mask_7_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_7_promoted = cast(dtype = mask_7_promoted_dtype_0, x = mask_7)[name = tensor("cast_214")]; tensor var_1055 = mul(x = expanded, y = mask_7_promoted)[name = tensor("op_1055")]; tensor var_1056 = equal(x = var_1055, y = x_3)[name = tensor("op_1056")]; tensor var_1060_axes_0 = const()[name = tensor("op_1060_axes_0"), val = tensor([-1])]; tensor var_1060 = expand_dims(axes = var_1060_axes_0, x = candidate_psuedo_counts_1)[name = tensor("op_1060")]; tensor var_1062_axes_0 = const()[name = tensor("op_1062_axes_0"), val = tensor([-1])]; tensor var_1062 = expand_dims(axes = var_1062_axes_0, x = var_1060)[name = tensor("op_1062")]; tensor expanded_counts_1 = mul(x = var_1062, y = var_1014_promoted)[name = tensor("expanded_counts_1")]; tensor match_1_promoted_dtype_0 = const()[name = tensor("match_1_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1056_to_fp32 = cast(dtype = match_1_promoted_dtype_0, x = var_1056)[name = tensor("cast_213")]; tensor var_1070 = mul(x = expanded_counts_1, y = var_1056_to_fp32)[name = tensor("op_1070")]; tensor var_1075_axes_0 = const()[name = tensor("op_1075_axes_0"), val = tensor([0])]; tensor var_1075_keep_dims_0 = const()[name = tensor("op_1075_keep_dims_0"), val = tensor(false)]; tensor var_1075 = reduce_sum(axes = var_1075_axes_0, keep_dims = var_1075_keep_dims_0, x = var_1070)[name = tensor("op_1075")]; tensor var_1076_promoted = const()[name = tensor("op_1076_promoted"), val = tensor(0x0p+0)]; tensor var_1077 = less(x = alignment_feedback, y = var_1076_promoted)[name = tensor("op_1077")]; tensor var_1077_promoted_dtype_0 = const()[name = tensor("op_1077_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1077_promoted = cast(dtype = var_1077_promoted_dtype_0, x = var_1077)[name = tensor("cast_212")]; tensor alignment_scaling = mul(x = var_1077_promoted, y = alignment_feedback)[name = tensor("alignment_scaling")]; tensor var_1098 = add(x = similarity_time_1, y = similarity_freq_1)[name = tensor("op_1098")]; tensor var_1100 = add(x = var_1098, y = similarity_location_1)[name = tensor("op_1100")]; tensor _inversed_v_y_0 = const()[name = tensor("_inversed_v_y_0"), val = tensor(0x1.555556p-2)]; tensor _inversed_v = mul(x = var_1100, y = _inversed_v_y_0)[name = tensor("_inversed_v")]; tensor var_1103 = mul(x = alignment_scaling, y = time_score_1)[name = tensor("op_1103")]; tensor var_1104 = mul(x = var_1103, y = freq_score_1)[name = tensor("op_1104")]; tensor var_1105 = mul(x = var_1104, y = location_score_1)[name = tensor("op_1105")]; tensor var_1106 = mul(x = var_1105, y = reduce_max_2)[name = tensor("op_1106")]; tensor candidate_psuedo_counts = abs(x = var_1106)[name = tensor("candidate_psuedo_counts")]; tensor var_1119_promoted = const()[name = tensor("op_1119_promoted"), val = tensor(0x0p+0)]; tensor mask_9 = less(x = var_1052, y = var_1119_promoted)[name = tensor("mask_9")]; tensor mask_9_promoted_dtype_0 = const()[name = tensor("mask_9_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_9_promoted = cast(dtype = mask_9_promoted_dtype_0, x = mask_9)[name = tensor("cast_211")]; tensor var_1121 = mul(x = expanded, y = mask_9_promoted)[name = tensor("op_1121")]; tensor var_1122 = equal(x = var_1121, y = x_3)[name = tensor("op_1122")]; tensor var_1126_axes_0 = const()[name = tensor("op_1126_axes_0"), val = tensor([-1])]; tensor var_1126 = expand_dims(axes = var_1126_axes_0, x = candidate_psuedo_counts)[name = tensor("op_1126")]; tensor var_1128_axes_0 = const()[name = tensor("op_1128_axes_0"), val = tensor([-1])]; tensor var_1128 = expand_dims(axes = var_1128_axes_0, x = var_1126)[name = tensor("op_1128")]; tensor expanded_counts = mul(x = var_1128, y = var_1014_promoted)[name = tensor("expanded_counts")]; tensor match_promoted_dtype_0 = const()[name = tensor("match_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1122_to_fp32 = cast(dtype = match_promoted_dtype_0, x = var_1122)[name = tensor("cast_210")]; tensor var_1136 = mul(x = expanded_counts, y = var_1122_to_fp32)[name = tensor("op_1136")]; tensor var_1141_axes_0 = const()[name = tensor("op_1141_axes_0"), val = tensor([0])]; tensor var_1141_keep_dims_0 = const()[name = tensor("op_1141_keep_dims_0"), val = tensor(false)]; tensor var_1141 = reduce_sum(axes = var_1141_axes_0, keep_dims = var_1141_keep_dims_0, x = var_1136)[name = tensor("op_1141")]; tensor var_1142 = const()[name = tensor("op_1142"), val = tensor(0x1.9cedbp-1)]; tensor var_1143 = mul(x = var_1075, y = var_1142)[name = tensor("op_1143")]; tensor var_1144 = const()[name = tensor("op_1144"), val = tensor(0x1.9cedbp-1)]; tensor var_1145 = mul(x = var_1141, y = var_1144)[name = tensor("op_1145")]; tensor var_1152_promoted = const()[name = tensor("op_1152_promoted"), val = tensor(0x1p+2)]; tensor var_1153 = not_equal(x = var_80, y = var_1152_promoted)[name = tensor("op_1153")]; tensor var_1153_promoted_dtype_0 = const()[name = tensor("op_1153_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1153_promoted = cast(dtype = var_1153_promoted_dtype_0, x = var_1153)[name = tensor("cast_209")]; tensor var_1154 = mul(x = var_1143, y = var_1153_promoted)[name = tensor("op_1154")]; tensor var_1155 = mul(x = var_1145, y = var_1153_promoted)[name = tensor("op_1155")]; tensor var_1163_begin_0 = const()[name = tensor("op_1163_begin_0"), val = tensor([0, 0])]; tensor var_1163_end_0 = const()[name = tensor("op_1163_end_0"), val = tensor([10, 1])]; tensor var_1163_end_mask_0 = const()[name = tensor("op_1163_end_mask_0"), val = tensor([true, false])]; tensor var_1163_squeeze_mask_0 = const()[name = tensor("op_1163_squeeze_mask_0"), val = tensor([false, true])]; tensor var_1163 = slice_by_index(begin = var_1163_begin_0, end = var_1163_end_0, end_mask = var_1163_end_mask_0, squeeze_mask = var_1163_squeeze_mask_0, x = similarityScores)[name = tensor("op_1163")]; tensor var_1165_axes_0 = const()[name = tensor("op_1165_axes_0"), val = tensor([-1])]; tensor var_1165 = expand_dims(axes = var_1165_axes_0, x = var_1163)[name = tensor("op_1165")]; tensor a_axes_0 = const()[name = tensor("a_axes_0"), val = tensor([-1])]; tensor a = expand_dims(axes = a_axes_0, x = var_1165)[name = tensor("a")]; tensor var_1175_begin_0 = const()[name = tensor("op_1175_begin_0"), val = tensor([0, 1])]; tensor var_1175_end_0 = const()[name = tensor("op_1175_end_0"), val = tensor([10, 2])]; tensor var_1175_end_mask_0 = const()[name = tensor("op_1175_end_mask_0"), val = tensor([true, false])]; tensor var_1175_squeeze_mask_0 = const()[name = tensor("op_1175_squeeze_mask_0"), val = tensor([false, true])]; tensor var_1175 = slice_by_index(begin = var_1175_begin_0, end = var_1175_end_0, end_mask = var_1175_end_mask_0, squeeze_mask = var_1175_squeeze_mask_0, x = similarityScores)[name = tensor("op_1175")]; tensor var_1177_axes_0 = const()[name = tensor("op_1177_axes_0"), val = tensor([-1])]; tensor var_1177 = expand_dims(axes = var_1177_axes_0, x = var_1175)[name = tensor("op_1177")]; tensor b_axes_0 = const()[name = tensor("b_axes_0"), val = tensor([-1])]; tensor b = expand_dims(axes = b_axes_0, x = var_1177)[name = tensor("b")]; tensor var_1187_begin_0 = const()[name = tensor("op_1187_begin_0"), val = tensor([0, 2])]; tensor var_1187_end_0 = const()[name = tensor("op_1187_end_0"), val = tensor([10, 3])]; tensor var_1187_end_mask_0 = const()[name = tensor("op_1187_end_mask_0"), val = tensor([true, false])]; tensor var_1187_squeeze_mask_0 = const()[name = tensor("op_1187_squeeze_mask_0"), val = tensor([false, true])]; tensor var_1187 = slice_by_index(begin = var_1187_begin_0, end = var_1187_end_0, end_mask = var_1187_end_mask_0, squeeze_mask = var_1187_squeeze_mask_0, x = similarityScores)[name = tensor("op_1187")]; tensor var_1188_promoted = const()[name = tensor("op_1188_promoted"), val = tensor(0x0p+0)]; tensor var_1189 = greater_equal(x = var_1187, y = var_1188_promoted)[name = tensor("op_1189")]; tensor var_1189_promoted_dtype_0 = const()[name = tensor("op_1189_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1189_promoted = cast(dtype = var_1189_promoted_dtype_0, x = var_1189)[name = tensor("cast_208")]; tensor var_1198 = mul(x = var_1189_promoted, y = var_1187)[name = tensor("op_1198")]; tensor var_1215 = not_equal(x = var_1163, y = var_1175)[name = tensor("op_1215")]; tensor var_1215_promoted_dtype_0 = const()[name = tensor("op_1215_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1215_promoted = cast(dtype = var_1215_promoted_dtype_0, x = var_1215)[name = tensor("cast_207")]; tensor scores = mul(x = var_1198, y = var_1215_promoted)[name = tensor("scores")]; tensor var_1219_promoted = const()[name = tensor("op_1219_promoted"), val = tensor(0x0p+0)]; tensor var_1220 = mul(x = a, y = var_1219_promoted)[name = tensor("op_1220")]; tensor candidates_expanded = add(x = var_808, y = var_1220)[name = tensor("candidates_expanded")]; tensor var_1223 = equal(x = candidates_expanded, y = a)[name = tensor("op_1223")]; tensor var_1223_promoted_dtype_0 = const()[name = tensor("op_1223_promoted_dtype_0"), val = tensor("int32")]; tensor var_1227 = equal(x = candidates_expanded, y = b)[name = tensor("op_1227")]; tensor var_1227_promoted_dtype_0 = const()[name = tensor("op_1227_promoted_dtype_0"), val = tensor("int32")]; tensor var_1232_axes_0 = const()[name = tensor("op_1232_axes_0"), val = tensor([-1])]; tensor var_1232 = expand_dims(axes = var_1232_axes_0, x = scores)[name = tensor("op_1232")]; tensor var_1234_axes_0 = const()[name = tensor("op_1234_axes_0"), val = tensor([-1])]; tensor var_1234 = expand_dims(axes = var_1234_axes_0, x = var_1232)[name = tensor("op_1234")]; tensor has_a_promoted_dtype_0 = const()[name = tensor("has_a_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1223_to_fp32 = cast(dtype = has_a_promoted_dtype_0, x = var_1223)[name = tensor("cast_204")]; tensor a_scores = mul(x = var_1223_to_fp32, y = var_1234)[name = tensor("a_scores")]; tensor has_b_promoted_dtype_0 = const()[name = tensor("has_b_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1227_to_fp32 = cast(dtype = has_b_promoted_dtype_0, x = var_1227)[name = tensor("cast_203")]; tensor b_scores = mul(x = var_1227_to_fp32, y = var_1234)[name = tensor("b_scores")]; tensor similarity_conc = add(x = a_scores, y = b_scores)[name = tensor("similarity_conc")]; tensor var_1227_promoted = cast(dtype = var_1227_promoted_dtype_0, x = var_1227)[name = tensor("cast_205")]; tensor var_1223_promoted = cast(dtype = var_1223_promoted_dtype_0, x = var_1223)[name = tensor("cast_206")]; tensor var_1244 = add(x = var_1223_promoted, y = var_1227_promoted)[name = tensor("op_1244")]; tensor var_1249_axes_0 = const()[name = tensor("op_1249_axes_0"), val = tensor([-1])]; tensor var_1249_keep_dims_0 = const()[name = tensor("op_1249_keep_dims_0"), val = tensor(false)]; tensor var_1249 = reduce_sum(axes = var_1249_axes_0, keep_dims = var_1249_keep_dims_0, x = var_1244)[name = tensor("op_1249")]; tensor var_1250 = const()[name = tensor("op_1250"), val = tensor(1)]; tensor var_1251 = greater(x = var_1249, y = var_1250)[name = tensor("op_1251")]; tensor var_1251_promoted_dtype_0 = const()[name = tensor("op_1251_promoted_dtype_0"), val = tensor("int32")]; tensor var_1256_axes_0 = const()[name = tensor("op_1256_axes_0"), val = tensor([-1])]; tensor var_1251_promoted = cast(dtype = var_1251_promoted_dtype_0, x = var_1251)[name = tensor("cast_202")]; tensor var_1256 = expand_dims(axes = var_1256_axes_0, x = var_1251_promoted)[name = tensor("op_1256")]; tensor var_1256_promoted_dtype_0 = const()[name = tensor("op_1256_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1256_promoted = cast(dtype = var_1256_promoted_dtype_0, x = var_1256)[name = tensor("cast_201")]; tensor var_1257 = mul(x = similarity_conc, y = var_1256_promoted)[name = tensor("op_1257")]; tensor var_1262_axes_0 = const()[name = tensor("op_1262_axes_0"), val = tensor([0])]; tensor var_1262_keep_dims_0 = const()[name = tensor("op_1262_keep_dims_0"), val = tensor(false)]; tensor var_1262 = reduce_sum(axes = var_1262_axes_0, keep_dims = var_1262_keep_dims_0, x = var_1257)[name = tensor("op_1262")]; tensor var_1263_promoted = const()[name = tensor("op_1263_promoted"), val = tensor(0x0p+0)]; tensor var_1264 = mul(x = x_3, y = var_1263_promoted)[name = tensor("op_1264")]; tensor var_1266_promoted = const()[name = tensor("op_1266_promoted"), val = tensor(0x1p+0)]; tensor var_1267 = add(x = var_1264, y = var_1266_promoted)[name = tensor("op_1267")]; tensor var_1272_axes_0 = const()[name = tensor("op_1272_axes_0"), val = tensor([0])]; tensor var_1272_keep_dims_0 = const()[name = tensor("op_1272_keep_dims_0"), val = tensor(false)]; tensor var_1272 = reduce_sum(axes = var_1272_axes_0, keep_dims = var_1272_keep_dims_0, x = var_1267)[name = tensor("op_1272")]; tensor var_1275_begin_0 = const()[name = tensor("op_1275_begin_0"), val = tensor([0])]; tensor var_1275_end_0 = const()[name = tensor("op_1275_end_0"), val = tensor([1])]; tensor var_1275_end_mask_0 = const()[name = tensor("op_1275_end_mask_0"), val = tensor([false])]; tensor var_1275_squeeze_mask_0 = const()[name = tensor("op_1275_squeeze_mask_0"), val = tensor([true])]; tensor var_1275 = slice_by_index(begin = var_1275_begin_0, end = var_1275_end_0, end_mask = var_1275_end_mask_0, squeeze_mask = var_1275_squeeze_mask_0, x = var_1272)[name = tensor("op_1275")]; tensor var_1276 = const()[name = tensor("op_1276"), val = tensor(0x1.f7942ep-2)]; tensor var_1277 = pow(x = var_1275, y = var_1276)[name = tensor("op_1277")]; tensor var_1278 = const()[name = tensor("op_1278"), val = tensor(0x1.81c8b2p-1)]; tensor var_1279 = mul(x = var_1277, y = var_1278)[name = tensor("op_1279")]; tensor var_1281 = const()[name = tensor("op_1281"), val = tensor(-0x1.6b9c96p-1)]; tensor var_1282 = add(x = var_1279, y = var_1281)[name = tensor("op_1282")]; tensor var_1285 = mul(x = var_1282, y = var_1262)[name = tensor("op_1285")]; tensor var_1291 = const()[name = tensor("op_1291"), val = tensor(0x0p+0)]; tensor var_1292 = mul(x = var_162, y = var_1291)[name = tensor("op_1292")]; tensor shape_18 = const()[name = tensor("shape_18"), val = tensor([9, 50, 15])]; tensor reshape_16 = const()[name = tensor("reshape_16"), val = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749])]; tensor reshape_17_shape_0 = const()[name = tensor("reshape_17_shape_0"), val = tensor([-1])]; tensor reshape_17 = reshape(shape = reshape_17_shape_0, x = var_1292)[name = tensor("reshape_17")]; tensor reshape_18_shape_0 = const()[name = tensor("reshape_18_shape_0"), val = tensor([-1])]; tensor reshape_18 = reshape(shape = reshape_18_shape_0, x = tuples)[name = tensor("reshape_18")]; tensor scatter_3_mode_0 = const()[name = tensor("scatter_3_mode_0"), val = tensor("update")]; tensor scatter_3_axis_0 = const()[name = tensor("scatter_3_axis_0"), val = tensor(0)]; tensor scatter_3 = scatter(axis = scatter_3_axis_0, data = reshape_18, indices = reshape_16, mode = scatter_3_mode_0, updates = reshape_17)[name = tensor("scatter_3")]; tensor reshape_19 = reshape(shape = shape_18, x = scatter_3)[name = tensor("reshape_19")]; tensor var_1300_begin_0 = const()[name = tensor("op_1300_begin_0"), val = tensor([8, 0, 0])]; tensor var_1300_end_0 = const()[name = tensor("op_1300_end_0"), val = tensor([9, 50, 15])]; tensor var_1300_end_mask_0 = const()[name = tensor("op_1300_end_mask_0"), val = tensor([false, true, true])]; tensor var_1300_squeeze_mask_0 = const()[name = tensor("op_1300_squeeze_mask_0"), val = tensor([true, false, false])]; tensor var_1300 = slice_by_index(begin = var_1300_begin_0, end = var_1300_end_0, end_mask = var_1300_end_mask_0, squeeze_mask = var_1300_squeeze_mask_0, x = reshape_19)[name = tensor("op_1300")]; tensor var_1301_promoted = const()[name = tensor("op_1301_promoted"), val = tensor(0x0p+0)]; tensor var_1302 = mul(x = var_1300, y = var_1301_promoted)[name = tensor("op_1302")]; tensor shape_19 = const()[name = tensor("shape_19"), val = tensor([9, 50, 15])]; tensor reshape_21 = const()[name = tensor("reshape_21"), val = tensor([6000, 6001, 6002, 6003, 6004, 6005, 6006, 6007, 6008, 6009, 6010, 6011, 6012, 6013, 6014, 6015, 6016, 6017, 6018, 6019, 6020, 6021, 6022, 6023, 6024, 6025, 6026, 6027, 6028, 6029, 6030, 6031, 6032, 6033, 6034, 6035, 6036, 6037, 6038, 6039, 6040, 6041, 6042, 6043, 6044, 6045, 6046, 6047, 6048, 6049, 6050, 6051, 6052, 6053, 6054, 6055, 6056, 6057, 6058, 6059, 6060, 6061, 6062, 6063, 6064, 6065, 6066, 6067, 6068, 6069, 6070, 6071, 6072, 6073, 6074, 6075, 6076, 6077, 6078, 6079, 6080, 6081, 6082, 6083, 6084, 6085, 6086, 6087, 6088, 6089, 6090, 6091, 6092, 6093, 6094, 6095, 6096, 6097, 6098, 6099, 6100, 6101, 6102, 6103, 6104, 6105, 6106, 6107, 6108, 6109, 6110, 6111, 6112, 6113, 6114, 6115, 6116, 6117, 6118, 6119, 6120, 6121, 6122, 6123, 6124, 6125, 6126, 6127, 6128, 6129, 6130, 6131, 6132, 6133, 6134, 6135, 6136, 6137, 6138, 6139, 6140, 6141, 6142, 6143, 6144, 6145, 6146, 6147, 6148, 6149, 6150, 6151, 6152, 6153, 6154, 6155, 6156, 6157, 6158, 6159, 6160, 6161, 6162, 6163, 6164, 6165, 6166, 6167, 6168, 6169, 6170, 6171, 6172, 6173, 6174, 6175, 6176, 6177, 6178, 6179, 6180, 6181, 6182, 6183, 6184, 6185, 6186, 6187, 6188, 6189, 6190, 6191, 6192, 6193, 6194, 6195, 6196, 6197, 6198, 6199, 6200, 6201, 6202, 6203, 6204, 6205, 6206, 6207, 6208, 6209, 6210, 6211, 6212, 6213, 6214, 6215, 6216, 6217, 6218, 6219, 6220, 6221, 6222, 6223, 6224, 6225, 6226, 6227, 6228, 6229, 6230, 6231, 6232, 6233, 6234, 6235, 6236, 6237, 6238, 6239, 6240, 6241, 6242, 6243, 6244, 6245, 6246, 6247, 6248, 6249, 6250, 6251, 6252, 6253, 6254, 6255, 6256, 6257, 6258, 6259, 6260, 6261, 6262, 6263, 6264, 6265, 6266, 6267, 6268, 6269, 6270, 6271, 6272, 6273, 6274, 6275, 6276, 6277, 6278, 6279, 6280, 6281, 6282, 6283, 6284, 6285, 6286, 6287, 6288, 6289, 6290, 6291, 6292, 6293, 6294, 6295, 6296, 6297, 6298, 6299, 6300, 6301, 6302, 6303, 6304, 6305, 6306, 6307, 6308, 6309, 6310, 6311, 6312, 6313, 6314, 6315, 6316, 6317, 6318, 6319, 6320, 6321, 6322, 6323, 6324, 6325, 6326, 6327, 6328, 6329, 6330, 6331, 6332, 6333, 6334, 6335, 6336, 6337, 6338, 6339, 6340, 6341, 6342, 6343, 6344, 6345, 6346, 6347, 6348, 6349, 6350, 6351, 6352, 6353, 6354, 6355, 6356, 6357, 6358, 6359, 6360, 6361, 6362, 6363, 6364, 6365, 6366, 6367, 6368, 6369, 6370, 6371, 6372, 6373, 6374, 6375, 6376, 6377, 6378, 6379, 6380, 6381, 6382, 6383, 6384, 6385, 6386, 6387, 6388, 6389, 6390, 6391, 6392, 6393, 6394, 6395, 6396, 6397, 6398, 6399, 6400, 6401, 6402, 6403, 6404, 6405, 6406, 6407, 6408, 6409, 6410, 6411, 6412, 6413, 6414, 6415, 6416, 6417, 6418, 6419, 6420, 6421, 6422, 6423, 6424, 6425, 6426, 6427, 6428, 6429, 6430, 6431, 6432, 6433, 6434, 6435, 6436, 6437, 6438, 6439, 6440, 6441, 6442, 6443, 6444, 6445, 6446, 6447, 6448, 6449, 6450, 6451, 6452, 6453, 6454, 6455, 6456, 6457, 6458, 6459, 6460, 6461, 6462, 6463, 6464, 6465, 6466, 6467, 6468, 6469, 6470, 6471, 6472, 6473, 6474, 6475, 6476, 6477, 6478, 6479, 6480, 6481, 6482, 6483, 6484, 6485, 6486, 6487, 6488, 6489, 6490, 6491, 6492, 6493, 6494, 6495, 6496, 6497, 6498, 6499, 6500, 6501, 6502, 6503, 6504, 6505, 6506, 6507, 6508, 6509, 6510, 6511, 6512, 6513, 6514, 6515, 6516, 6517, 6518, 6519, 6520, 6521, 6522, 6523, 6524, 6525, 6526, 6527, 6528, 6529, 6530, 6531, 6532, 6533, 6534, 6535, 6536, 6537, 6538, 6539, 6540, 6541, 6542, 6543, 6544, 6545, 6546, 6547, 6548, 6549, 6550, 6551, 6552, 6553, 6554, 6555, 6556, 6557, 6558, 6559, 6560, 6561, 6562, 6563, 6564, 6565, 6566, 6567, 6568, 6569, 6570, 6571, 6572, 6573, 6574, 6575, 6576, 6577, 6578, 6579, 6580, 6581, 6582, 6583, 6584, 6585, 6586, 6587, 6588, 6589, 6590, 6591, 6592, 6593, 6594, 6595, 6596, 6597, 6598, 6599, 6600, 6601, 6602, 6603, 6604, 6605, 6606, 6607, 6608, 6609, 6610, 6611, 6612, 6613, 6614, 6615, 6616, 6617, 6618, 6619, 6620, 6621, 6622, 6623, 6624, 6625, 6626, 6627, 6628, 6629, 6630, 6631, 6632, 6633, 6634, 6635, 6636, 6637, 6638, 6639, 6640, 6641, 6642, 6643, 6644, 6645, 6646, 6647, 6648, 6649, 6650, 6651, 6652, 6653, 6654, 6655, 6656, 6657, 6658, 6659, 6660, 6661, 6662, 6663, 6664, 6665, 6666, 6667, 6668, 6669, 6670, 6671, 6672, 6673, 6674, 6675, 6676, 6677, 6678, 6679, 6680, 6681, 6682, 6683, 6684, 6685, 6686, 6687, 6688, 6689, 6690, 6691, 6692, 6693, 6694, 6695, 6696, 6697, 6698, 6699, 6700, 6701, 6702, 6703, 6704, 6705, 6706, 6707, 6708, 6709, 6710, 6711, 6712, 6713, 6714, 6715, 6716, 6717, 6718, 6719, 6720, 6721, 6722, 6723, 6724, 6725, 6726, 6727, 6728, 6729, 6730, 6731, 6732, 6733, 6734, 6735, 6736, 6737, 6738, 6739, 6740, 6741, 6742, 6743, 6744, 6745, 6746, 6747, 6748, 6749])]; tensor reshape_22_shape_0 = const()[name = tensor("reshape_22_shape_0"), val = tensor([-1])]; tensor reshape_22 = reshape(shape = reshape_22_shape_0, x = var_1302)[name = tensor("reshape_22")]; tensor reshape_23_shape_0 = const()[name = tensor("reshape_23_shape_0"), val = tensor([-1])]; tensor reshape_23 = reshape(shape = reshape_23_shape_0, x = reshape_19)[name = tensor("reshape_23")]; tensor scatter_4_mode_0 = const()[name = tensor("scatter_4_mode_0"), val = tensor("update")]; tensor scatter_4_axis_0 = const()[name = tensor("scatter_4_axis_0"), val = tensor(0)]; tensor scatter_4 = scatter(axis = scatter_4_axis_0, data = reshape_23, indices = reshape_21, mode = scatter_4_mode_0, updates = reshape_22)[name = tensor("scatter_4")]; tensor reshape_24 = reshape(shape = shape_19, x = scatter_4)[name = tensor("reshape_24")]; tensor search_ranks_begin_0 = const()[name = tensor("search_ranks_begin_0"), val = tensor([7, 0, 0])]; tensor search_ranks_end_0 = const()[name = tensor("search_ranks_end_0"), val = tensor([8, 50, 15])]; tensor search_ranks_end_mask_0 = const()[name = tensor("search_ranks_end_mask_0"), val = tensor([false, true, true])]; tensor search_ranks_squeeze_mask_0 = const()[name = tensor("search_ranks_squeeze_mask_0"), val = tensor([true, false, false])]; tensor search_ranks = slice_by_index(begin = search_ranks_begin_0, end = search_ranks_end_0, end_mask = search_ranks_end_mask_0, squeeze_mask = search_ranks_squeeze_mask_0, x = reshape_24)[name = tensor("search_ranks")]; tensor risk_level_only_begin_0 = const()[name = tensor("risk_level_only_begin_0"), val = tensor([0, 0])]; tensor risk_level_only_end_0 = const()[name = tensor("risk_level_only_end_0"), val = tensor([50, 1])]; tensor risk_level_only_end_mask_0 = const()[name = tensor("risk_level_only_end_mask_0"), val = tensor([true, false])]; tensor risk_level_only = slice_by_index(begin = risk_level_only_begin_0, end = risk_level_only_end_0, end_mask = risk_level_only_end_mask_0, x = riskLevel)[name = tensor("risk_level_only")]; tensor ignore_level_max = const()[name = tensor("ignore_level_max"), val = tensor([0x1.8p+1])]; tensor var_1322 = greater(x = risk_level_only, y = ignore_level_max)[name = tensor("op_1322")]; tensor var_1323 = const()[name = tensor("op_1323"), val = tensor([0x1.4p+2])]; tensor var_1324 = less(x = risk_level_only, y = var_1323)[name = tensor("op_1324")]; tensor var_1325 = logical_and(x = var_1322, y = var_1324)[name = tensor("op_1325")]; tensor var_1325_promoted_dtype_0 = const()[name = tensor("op_1325_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1330 = const()[name = tensor("op_1330"), val = tensor(0x1.028f5cp+0)]; tensor var_1331 = add(x = search_ranks, y = var_1330)[name = tensor("op_1331")]; tensor var_1332_epsilon_0 = const()[name = tensor("op_1332_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_1332 = inverse(epsilon = var_1332_epsilon_0, x = var_1331)[name = tensor("op_1332")]; tensor var_1325_promoted = cast(dtype = var_1325_promoted_dtype_0, x = var_1325)[name = tensor("cast_200")]; tensor var_1335 = mul(x = var_1332, y = var_1325_promoted)[name = tensor("op_1335")]; tensor var_1336 = const()[name = tensor("op_1336"), val = tensor(-0x1p-1)]; tensor var_1337 = greater(x = search_ranks, y = var_1336)[name = tensor("op_1337")]; tensor var_1337_promoted_dtype_0 = const()[name = tensor("op_1337_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1337_promoted = cast(dtype = var_1337_promoted_dtype_0, x = var_1337)[name = tensor("cast_199")]; tensor search_ranks_adjusted = mul(x = var_1335, y = var_1337_promoted)[name = tensor("search_ranks_adjusted")]; tensor shape_20 = const()[name = tensor("shape_20"), val = tensor([9, 50, 15])]; tensor reshape_26 = const()[name = tensor("reshape_26"), val = tensor([5250, 5251, 5252, 5253, 5254, 5255, 5256, 5257, 5258, 5259, 5260, 5261, 5262, 5263, 5264, 5265, 5266, 5267, 5268, 5269, 5270, 5271, 5272, 5273, 5274, 5275, 5276, 5277, 5278, 5279, 5280, 5281, 5282, 5283, 5284, 5285, 5286, 5287, 5288, 5289, 5290, 5291, 5292, 5293, 5294, 5295, 5296, 5297, 5298, 5299, 5300, 5301, 5302, 5303, 5304, 5305, 5306, 5307, 5308, 5309, 5310, 5311, 5312, 5313, 5314, 5315, 5316, 5317, 5318, 5319, 5320, 5321, 5322, 5323, 5324, 5325, 5326, 5327, 5328, 5329, 5330, 5331, 5332, 5333, 5334, 5335, 5336, 5337, 5338, 5339, 5340, 5341, 5342, 5343, 5344, 5345, 5346, 5347, 5348, 5349, 5350, 5351, 5352, 5353, 5354, 5355, 5356, 5357, 5358, 5359, 5360, 5361, 5362, 5363, 5364, 5365, 5366, 5367, 5368, 5369, 5370, 5371, 5372, 5373, 5374, 5375, 5376, 5377, 5378, 5379, 5380, 5381, 5382, 5383, 5384, 5385, 5386, 5387, 5388, 5389, 5390, 5391, 5392, 5393, 5394, 5395, 5396, 5397, 5398, 5399, 5400, 5401, 5402, 5403, 5404, 5405, 5406, 5407, 5408, 5409, 5410, 5411, 5412, 5413, 5414, 5415, 5416, 5417, 5418, 5419, 5420, 5421, 5422, 5423, 5424, 5425, 5426, 5427, 5428, 5429, 5430, 5431, 5432, 5433, 5434, 5435, 5436, 5437, 5438, 5439, 5440, 5441, 5442, 5443, 5444, 5445, 5446, 5447, 5448, 5449, 5450, 5451, 5452, 5453, 5454, 5455, 5456, 5457, 5458, 5459, 5460, 5461, 5462, 5463, 5464, 5465, 5466, 5467, 5468, 5469, 5470, 5471, 5472, 5473, 5474, 5475, 5476, 5477, 5478, 5479, 5480, 5481, 5482, 5483, 5484, 5485, 5486, 5487, 5488, 5489, 5490, 5491, 5492, 5493, 5494, 5495, 5496, 5497, 5498, 5499, 5500, 5501, 5502, 5503, 5504, 5505, 5506, 5507, 5508, 5509, 5510, 5511, 5512, 5513, 5514, 5515, 5516, 5517, 5518, 5519, 5520, 5521, 5522, 5523, 5524, 5525, 5526, 5527, 5528, 5529, 5530, 5531, 5532, 5533, 5534, 5535, 5536, 5537, 5538, 5539, 5540, 5541, 5542, 5543, 5544, 5545, 5546, 5547, 5548, 5549, 5550, 5551, 5552, 5553, 5554, 5555, 5556, 5557, 5558, 5559, 5560, 5561, 5562, 5563, 5564, 5565, 5566, 5567, 5568, 5569, 5570, 5571, 5572, 5573, 5574, 5575, 5576, 5577, 5578, 5579, 5580, 5581, 5582, 5583, 5584, 5585, 5586, 5587, 5588, 5589, 5590, 5591, 5592, 5593, 5594, 5595, 5596, 5597, 5598, 5599, 5600, 5601, 5602, 5603, 5604, 5605, 5606, 5607, 5608, 5609, 5610, 5611, 5612, 5613, 5614, 5615, 5616, 5617, 5618, 5619, 5620, 5621, 5622, 5623, 5624, 5625, 5626, 5627, 5628, 5629, 5630, 5631, 5632, 5633, 5634, 5635, 5636, 5637, 5638, 5639, 5640, 5641, 5642, 5643, 5644, 5645, 5646, 5647, 5648, 5649, 5650, 5651, 5652, 5653, 5654, 5655, 5656, 5657, 5658, 5659, 5660, 5661, 5662, 5663, 5664, 5665, 5666, 5667, 5668, 5669, 5670, 5671, 5672, 5673, 5674, 5675, 5676, 5677, 5678, 5679, 5680, 5681, 5682, 5683, 5684, 5685, 5686, 5687, 5688, 5689, 5690, 5691, 5692, 5693, 5694, 5695, 5696, 5697, 5698, 5699, 5700, 5701, 5702, 5703, 5704, 5705, 5706, 5707, 5708, 5709, 5710, 5711, 5712, 5713, 5714, 5715, 5716, 5717, 5718, 5719, 5720, 5721, 5722, 5723, 5724, 5725, 5726, 5727, 5728, 5729, 5730, 5731, 5732, 5733, 5734, 5735, 5736, 5737, 5738, 5739, 5740, 5741, 5742, 5743, 5744, 5745, 5746, 5747, 5748, 5749, 5750, 5751, 5752, 5753, 5754, 5755, 5756, 5757, 5758, 5759, 5760, 5761, 5762, 5763, 5764, 5765, 5766, 5767, 5768, 5769, 5770, 5771, 5772, 5773, 5774, 5775, 5776, 5777, 5778, 5779, 5780, 5781, 5782, 5783, 5784, 5785, 5786, 5787, 5788, 5789, 5790, 5791, 5792, 5793, 5794, 5795, 5796, 5797, 5798, 5799, 5800, 5801, 5802, 5803, 5804, 5805, 5806, 5807, 5808, 5809, 5810, 5811, 5812, 5813, 5814, 5815, 5816, 5817, 5818, 5819, 5820, 5821, 5822, 5823, 5824, 5825, 5826, 5827, 5828, 5829, 5830, 5831, 5832, 5833, 5834, 5835, 5836, 5837, 5838, 5839, 5840, 5841, 5842, 5843, 5844, 5845, 5846, 5847, 5848, 5849, 5850, 5851, 5852, 5853, 5854, 5855, 5856, 5857, 5858, 5859, 5860, 5861, 5862, 5863, 5864, 5865, 5866, 5867, 5868, 5869, 5870, 5871, 5872, 5873, 5874, 5875, 5876, 5877, 5878, 5879, 5880, 5881, 5882, 5883, 5884, 5885, 5886, 5887, 5888, 5889, 5890, 5891, 5892, 5893, 5894, 5895, 5896, 5897, 5898, 5899, 5900, 5901, 5902, 5903, 5904, 5905, 5906, 5907, 5908, 5909, 5910, 5911, 5912, 5913, 5914, 5915, 5916, 5917, 5918, 5919, 5920, 5921, 5922, 5923, 5924, 5925, 5926, 5927, 5928, 5929, 5930, 5931, 5932, 5933, 5934, 5935, 5936, 5937, 5938, 5939, 5940, 5941, 5942, 5943, 5944, 5945, 5946, 5947, 5948, 5949, 5950, 5951, 5952, 5953, 5954, 5955, 5956, 5957, 5958, 5959, 5960, 5961, 5962, 5963, 5964, 5965, 5966, 5967, 5968, 5969, 5970, 5971, 5972, 5973, 5974, 5975, 5976, 5977, 5978, 5979, 5980, 5981, 5982, 5983, 5984, 5985, 5986, 5987, 5988, 5989, 5990, 5991, 5992, 5993, 5994, 5995, 5996, 5997, 5998, 5999])]; tensor reshape_27_shape_0 = const()[name = tensor("reshape_27_shape_0"), val = tensor([-1])]; tensor reshape_27 = reshape(shape = reshape_27_shape_0, x = search_ranks_adjusted)[name = tensor("reshape_27")]; tensor reshape_28_shape_0 = const()[name = tensor("reshape_28_shape_0"), val = tensor([-1])]; tensor reshape_28 = reshape(shape = reshape_28_shape_0, x = reshape_24)[name = tensor("reshape_28")]; tensor scatter_5_mode_0 = const()[name = tensor("scatter_5_mode_0"), val = tensor("update")]; tensor scatter_5_axis_0 = const()[name = tensor("scatter_5_axis_0"), val = tensor(0)]; tensor scatter_5 = scatter(axis = scatter_5_axis_0, data = reshape_28, indices = reshape_26, mode = scatter_5_mode_0, updates = reshape_27)[name = tensor("scatter_5")]; tensor reshape_29 = reshape(shape = shape_20, x = scatter_5)[name = tensor("reshape_29")]; tensor var_1344 = const()[name = tensor("op_1344"), val = tensor(-0x1.99999ap-4)]; tensor var_1345 = greater(x = reshape_29, y = var_1344)[name = tensor("op_1345")]; tensor var_1345_promoted_dtype_0 = const()[name = tensor("op_1345_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1345_promoted = cast(dtype = var_1345_promoted_dtype_0, x = var_1345)[name = tensor("cast_198")]; tensor var_1346 = mul(x = reshape_29, y = var_1345_promoted)[name = tensor("op_1346")]; tensor t_axes_0 = const()[name = tensor("t_axes_0"), val = tensor([0])]; tensor t_keep_dims_0 = const()[name = tensor("t_keep_dims_0"), val = tensor(false)]; tensor t = reduce_sum(axes = t_axes_0, keep_dims = t_keep_dims_0, x = var_1346)[name = tensor("t")]; tensor var_1358_axes_0 = const()[name = tensor("op_1358_axes_0"), val = tensor([0])]; tensor var_1358_keep_dims_0 = const()[name = tensor("op_1358_keep_dims_0"), val = tensor(false)]; tensor var_1358 = reduce_sum(axes = var_1358_axes_0, keep_dims = var_1358_keep_dims_0, x = t)[name = tensor("op_1358")]; tensor var_1360 = const()[name = tensor("op_1360"), val = tensor(0x1.a36e2ep-14)]; tensor sums = add(x = var_1358, y = var_1360)[name = tensor("sums")]; tensor var_1364 = real_div(x = t, y = sums)[name = tensor("op_1364")]; tensor var_1365 = const()[name = tensor("op_1365"), val = tensor(0x1.a08aep+3)]; tensor var_1366 = mul(x = var_1364, y = var_1365)[name = tensor("op_1366")]; tensor x_5_perm_0 = const()[name = tensor("x_5_perm_0"), val = tensor([1, 0])]; tensor var_1372_begin_0 = const()[name = tensor("op_1372_begin_0"), val = tensor([0, 0])]; tensor var_1372_end_0 = const()[name = tensor("op_1372_end_0"), val = tensor([1, 1000])]; tensor var_1372_end_mask_0 = const()[name = tensor("op_1372_end_mask_0"), val = tensor([false, true])]; tensor var_1372_squeeze_mask_0 = const()[name = tensor("op_1372_squeeze_mask_0"), val = tensor([true, false])]; tensor x_5 = transpose(perm = x_5_perm_0, x = tupleInteractions_candidates)[name = tensor("transpose_34")]; tensor var_1372 = slice_by_index(begin = var_1372_begin_0, end = var_1372_end_0, end_mask = var_1372_end_mask_0, squeeze_mask = var_1372_squeeze_mask_0, x = x_5)[name = tensor("op_1372")]; tensor var_1373 = const()[name = tensor("op_1373"), val = tensor(-0x1.e848p+19)]; tensor var_1374 = greater(x = var_1372, y = var_1373)[name = tensor("op_1374")]; tensor var_1374_promoted_dtype_0 = const()[name = tensor("op_1374_promoted_dtype_0"), val = tensor("int32")]; tensor var_1377 = const()[name = tensor("op_1377"), val = tensor(-0x1.e848p+19)]; tensor ones_3_promoted_dtype_0 = const()[name = tensor("ones_3_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1374_to_fp32 = cast(dtype = ones_3_promoted_dtype_0, x = var_1374)[name = tensor("cast_196")]; tensor small_3 = mul(x = var_1374_to_fp32, y = var_1377)[name = tensor("small_3")]; tensor var_1379 = const()[name = tensor("op_1379"), val = tensor(0x1.e848p+19)]; tensor big_3 = mul(x = var_1374_to_fp32, y = var_1379)[name = tensor("big_3")]; tensor var_1381 = const()[name = tensor("op_1381"), val = tensor(0)]; tensor var_1374_promoted = cast(dtype = var_1374_promoted_dtype_0, x = var_1374)[name = tensor("cast_197")]; tensor zeros_3 = mul(x = var_1374_promoted, y = var_1381)[name = tensor("zeros_3")]; tensor var_1384_axes_0 = const()[name = tensor("op_1384_axes_0"), val = tensor([0])]; tensor var_1384 = expand_dims(axes = var_1384_axes_0, x = small_3)[name = tensor("op_1384")]; tensor var_1386_axes_0 = const()[name = tensor("op_1386_axes_0"), val = tensor([0])]; tensor var_1386 = expand_dims(axes = var_1386_axes_0, x = big_3)[name = tensor("op_1386")]; tensor var_1388 = const()[name = tensor("op_1388"), val = tensor(0)]; tensor x_padded_3_interleave_0 = const()[name = tensor("x_padded_3_interleave_0"), val = tensor(false)]; tensor x_padded_3 = concat(axis = var_1388, interleave = x_padded_3_interleave_0, values = (var_1384, x_5, var_1386))[name = tensor("x_padded_3")]; tensor var_1390 = const()[name = tensor("op_1390"), val = tensor(0)]; tensor logical_not_2 = const()[name = tensor("logical_not_2"), val = tensor(true)]; tensor i_3 = argsort(ascending = logical_not_2, axis = var_1390, x = x_padded_3)[name = tensor("i_3")]; tensor by_x_3 = gather_along_axis(axis = var_1390, indices = i_3, x = x_padded_3)[name = tensor("by_x_3")]; tensor var_1398_begin_0 = const()[name = tensor("op_1398_begin_0"), val = tensor([1, 0])]; tensor var_1398_end_0 = const()[name = tensor("op_1398_end_0"), val = tensor([3, 1000])]; tensor var_1398_end_mask_0 = const()[name = tensor("op_1398_end_mask_0"), val = tensor([false, true])]; tensor var_1398 = slice_by_index(begin = var_1398_begin_0, end = var_1398_end_0, end_mask = var_1398_end_mask_0, x = by_x_3)[name = tensor("op_1398")]; tensor var_1403_begin_0 = const()[name = tensor("op_1403_begin_0"), val = tensor([0, 0])]; tensor var_1403_end_0 = const()[name = tensor("op_1403_end_0"), val = tensor([2, 1000])]; tensor var_1403_end_mask_0 = const()[name = tensor("op_1403_end_mask_0"), val = tensor([false, true])]; tensor var_1403 = slice_by_index(begin = var_1403_begin_0, end = var_1403_end_0, end_mask = var_1403_end_mask_0, x = by_x_3)[name = tensor("op_1403")]; tensor var_1405 = sub(x = var_1398, y = var_1403)[name = tensor("op_1405")]; tensor var_1406_promoted = const()[name = tensor("op_1406_promoted"), val = tensor(0x0p+0)]; tensor var_1407 = greater(x = var_1405, y = var_1406_promoted)[name = tensor("op_1407")]; tensor var_1407_promoted_dtype_0 = const()[name = tensor("op_1407_promoted_dtype_0"), val = tensor("int32")]; tensor var_1411_axes_0 = const()[name = tensor("op_1411_axes_0"), val = tensor([0])]; tensor var_1411 = expand_dims(axes = var_1411_axes_0, x = zeros_3)[name = tensor("op_1411")]; tensor var_1415 = const()[name = tensor("op_1415"), val = tensor(0)]; tensor mask_13_interleave_0 = const()[name = tensor("mask_13_interleave_0"), val = tensor(false)]; tensor var_1407_promoted = cast(dtype = var_1407_promoted_dtype_0, x = var_1407)[name = tensor("cast_195")]; tensor mask_13 = concat(axis = var_1415, interleave = mask_13_interleave_0, values = (var_1411, var_1407_promoted, var_1411))[name = tensor("mask_13")]; tensor mask_13_promoted_dtype_0 = const()[name = tensor("mask_13_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_13_promoted = cast(dtype = mask_13_promoted_dtype_0, x = mask_13)[name = tensor("cast_194")]; tensor var_1417 = mul(x = by_x_3, y = mask_13_promoted)[name = tensor("op_1417")]; tensor var_1418 = const()[name = tensor("op_1418"), val = tensor(0)]; tensor logical_not_3 = const()[name = tensor("logical_not_3"), val = tensor(true)]; tensor var_1420 = argsort(ascending = logical_not_3, axis = var_1418, x = i_3)[name = tensor("op_1420")]; tensor var_1421 = const()[name = tensor("op_1421"), val = tensor(0)]; tensor unique_3 = gather_along_axis(axis = var_1421, indices = var_1420, x = var_1417)[name = tensor("unique_3")]; tensor unique_ti_1_begin_0 = const()[name = tensor("unique_ti_1_begin_0"), val = tensor([1, 0])]; tensor unique_ti_1_end_0 = const()[name = tensor("unique_ti_1_end_0"), val = tensor([3, 1000])]; tensor unique_ti_1_end_mask_0 = const()[name = tensor("unique_ti_1_end_mask_0"), val = tensor([false, true])]; tensor unique_ti_1 = slice_by_index(begin = unique_ti_1_begin_0, end = unique_ti_1_end_0, end_mask = unique_ti_1_end_mask_0, x = unique_3)[name = tensor("unique_ti_1")]; tensor unique_ti_3_perm_0 = const()[name = tensor("unique_ti_3_perm_0"), val = tensor([1, 0])]; tensor var_1432 = const()[name = tensor("op_1432"), val = tensor(0x1p-1)]; tensor unique_ti_3 = transpose(perm = unique_ti_3_perm_0, x = unique_ti_1)[name = tensor("transpose_33")]; tensor var_1433 = greater(x = unique_ti_3, y = var_1432)[name = tensor("op_1433")]; tensor var_1433_promoted_dtype_0 = const()[name = tensor("op_1433_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1433_promoted = cast(dtype = var_1433_promoted_dtype_0, x = var_1433)[name = tensor("cast_193")]; tensor var_1437 = mul(x = unique_ti_3, y = var_1433_promoted)[name = tensor("op_1437")]; tensor var_1438 = const()[name = tensor("op_1438"), val = tensor(0x1p+0)]; tensor var_1440 = sub(x = var_1438, y = var_1433_promoted)[name = tensor("op_1440")]; tensor var_1441_promoted = const()[name = tensor("op_1441_promoted"), val = tensor(-0x1p+0)]; tensor var_1442 = mul(x = var_1440, y = var_1441_promoted)[name = tensor("op_1442")]; tensor unique_ti = add(x = var_1437, y = var_1442)[name = tensor("unique_ti")]; tensor var_1446_axes_0 = const()[name = tensor("op_1446_axes_0"), val = tensor([0])]; tensor var_1446 = expand_dims(axes = var_1446_axes_0, x = unique_ti)[name = tensor("op_1446")]; tensor inflated_history_axes_0 = const()[name = tensor("inflated_history_axes_0"), val = tensor([0])]; tensor inflated_history = expand_dims(axes = inflated_history_axes_0, x = var_1446)[name = tensor("inflated_history")]; tensor var_1450_axes_0 = const()[name = tensor("op_1450_axes_0"), val = tensor([-1])]; tensor var_1450 = expand_dims(axes = var_1450_axes_0, x = x_3)[name = tensor("op_1450")]; tensor inflated_tuples_axes_0 = const()[name = tensor("inflated_tuples_axes_0"), val = tensor([-1])]; tensor inflated_tuples = expand_dims(axes = inflated_tuples_axes_0, x = var_1450)[name = tensor("inflated_tuples")]; tensor var_1454 = sub(x = inflated_history, y = inflated_tuples)[name = tensor("op_1454")]; tensor var_1455 = abs(x = var_1454)[name = tensor("op_1455")]; tensor var_1456 = const()[name = tensor("op_1456"), val = tensor(0x1.0624dep-10)]; tensor candidate_compare = less(x = var_1455, y = var_1456)[name = tensor("candidate_compare")]; tensor cast_41_dtype_0 = const()[name = tensor("cast_41_dtype_0"), val = tensor("fp32")]; tensor per_tuple_match_axes_0 = const()[name = tensor("per_tuple_match_axes_0"), val = tensor([3])]; tensor per_tuple_match_keep_dims_0 = const()[name = tensor("per_tuple_match_keep_dims_0"), val = tensor(false)]; tensor cast_41 = cast(dtype = cast_41_dtype_0, x = candidate_compare)[name = tensor("cast_192")]; tensor per_tuple_match = reduce_sum(axes = per_tuple_match_axes_0, keep_dims = per_tuple_match_keep_dims_0, x = cast_41)[name = tensor("per_tuple_match")]; tensor matching_counts_axes_0 = const()[name = tensor("matching_counts_axes_0"), val = tensor([1])]; tensor matching_counts_keep_dims_0 = const()[name = tensor("matching_counts_keep_dims_0"), val = tensor(false)]; tensor matching_counts = reduce_sum(axes = matching_counts_axes_0, keep_dims = matching_counts_keep_dims_0, x = per_tuple_match)[name = tensor("matching_counts")]; tensor var_1469_promoted = const()[name = tensor("op_1469_promoted"), val = tensor(0x1p+1)]; tensor var_1470 = sub(x = matching_counts, y = var_1469_promoted)[name = tensor("op_1470")]; tensor var_1471 = abs(x = var_1470)[name = tensor("op_1471")]; tensor var_1472 = const()[name = tensor("op_1472"), val = tensor(0x1.0624dep-10)]; tensor var_1473 = less(x = var_1471, y = var_1472)[name = tensor("op_1473")]; tensor var_1473_promoted_dtype_0 = const()[name = tensor("op_1473_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1478_axes_0 = const()[name = tensor("op_1478_axes_0"), val = tensor([1])]; tensor var_1473_promoted = cast(dtype = var_1473_promoted_dtype_0, x = var_1473)[name = tensor("cast_191")]; tensor var_1478 = expand_dims(axes = var_1478_axes_0, x = var_1473_promoted)[name = tensor("op_1478")]; tensor x_7 = mul(x = per_tuple_match, y = var_1478)[name = tensor("x_7")]; tensor var_1487_promoted = const()[name = tensor("op_1487_promoted"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(24512)))]; tensor numbered = mul(x = x_7, y = var_1487_promoted)[name = tensor("numbered")]; tensor var_1489 = const()[name = tensor("op_1489"), val = tensor(0x1p-1)]; tensor var_1490 = less(x = numbered, y = var_1489)[name = tensor("op_1490")]; tensor var_1491 = const()[name = tensor("op_1491"), val = tensor(100)]; tensor var_1490_promoted_dtype_0 = const()[name = tensor("op_1490_promoted_dtype_0"), val = tensor("int32")]; tensor var_1490_promoted = cast(dtype = var_1490_promoted_dtype_0, x = var_1490)[name = tensor("cast_190")]; tensor zero_mask = mul(x = var_1490_promoted, y = var_1491)[name = tensor("zero_mask")]; tensor zero_mask_promoted_dtype_0 = const()[name = tensor("zero_mask_promoted_dtype_0"), val = tensor("fp32")]; tensor zero_mask_promoted = cast(dtype = zero_mask_promoted_dtype_0, x = zero_mask)[name = tensor("cast_189")]; tensor var_1494 = add(x = numbered, y = zero_mask_promoted)[name = tensor("op_1494")]; tensor reduce_min_0_axes_0 = const()[name = tensor("reduce_min_0_axes_0"), val = tensor([1])]; tensor reduce_min_0_keep_dims_0 = const()[name = tensor("reduce_min_0_keep_dims_0"), val = tensor(false)]; tensor reduce_min_0 = reduce_min(axes = reduce_min_0_axes_0, keep_dims = reduce_min_0_keep_dims_0, x = var_1494)[name = tensor("reduce_min_0")]; tensor var_1500_promoted = const()[name = tensor("op_1500_promoted"), val = tensor(0x1.9p+6)]; tensor var_1501 = sub(x = reduce_min_0, y = var_1500_promoted)[name = tensor("op_1501")]; tensor var_1502 = abs(x = var_1501)[name = tensor("op_1502")]; tensor var_1503 = const()[name = tensor("op_1503"), val = tensor(0x1.0624dep-10)]; tensor var_1504 = greater(x = var_1502, y = var_1503)[name = tensor("op_1504")]; tensor var_1504_promoted_dtype_0 = const()[name = tensor("op_1504_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1504_promoted = cast(dtype = var_1504_promoted_dtype_0, x = var_1504)[name = tensor("cast_188")]; tensor mins = mul(x = reduce_min_0, y = var_1504_promoted)[name = tensor("mins")]; tensor reduce_max_3_axes_0 = const()[name = tensor("reduce_max_3_axes_0"), val = tensor([1])]; tensor reduce_max_3_keep_dims_0 = const()[name = tensor("reduce_max_3_keep_dims_0"), val = tensor(false)]; tensor reduce_max_3 = reduce_max(axes = reduce_max_3_axes_0, keep_dims = reduce_max_3_keep_dims_0, x = numbered)[name = tensor("reduce_max_3")]; tensor var_1510_promoted = const()[name = tensor("op_1510_promoted"), val = tensor(0x1.ep+3)]; tensor var_1511 = mul(x = mins, y = var_1510_promoted)[name = tensor("op_1511")]; tensor pair_ids = add(x = var_1511, y = reduce_max_3)[name = tensor("pair_ids")]; tensor x_9_perm_0 = const()[name = tensor("x_9_perm_0"), val = tensor([1, 0])]; tensor var_1519_begin_0 = const()[name = tensor("op_1519_begin_0"), val = tensor([0, 0])]; tensor var_1519_end_0 = const()[name = tensor("op_1519_end_0"), val = tensor([1, 50])]; tensor var_1519_end_mask_0 = const()[name = tensor("op_1519_end_mask_0"), val = tensor([false, true])]; tensor var_1519_squeeze_mask_0 = const()[name = tensor("op_1519_squeeze_mask_0"), val = tensor([true, false])]; tensor x_9 = transpose(perm = x_9_perm_0, x = pair_ids)[name = tensor("transpose_32")]; tensor var_1519 = slice_by_index(begin = var_1519_begin_0, end = var_1519_end_0, end_mask = var_1519_end_mask_0, squeeze_mask = var_1519_squeeze_mask_0, x = x_9)[name = tensor("op_1519")]; tensor var_1520 = const()[name = tensor("op_1520"), val = tensor(-0x1.e848p+19)]; tensor var_1521 = greater(x = var_1519, y = var_1520)[name = tensor("op_1521")]; tensor var_1521_promoted_dtype_0 = const()[name = tensor("op_1521_promoted_dtype_0"), val = tensor("int32")]; tensor var_1524 = const()[name = tensor("op_1524"), val = tensor(-0x1.e848p+19)]; tensor ones_5_promoted_dtype_0 = const()[name = tensor("ones_5_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1521_to_fp32 = cast(dtype = ones_5_promoted_dtype_0, x = var_1521)[name = tensor("cast_186")]; tensor small_5 = mul(x = var_1521_to_fp32, y = var_1524)[name = tensor("small_5")]; tensor var_1526 = const()[name = tensor("op_1526"), val = tensor(0x1.e848p+19)]; tensor big_5 = mul(x = var_1521_to_fp32, y = var_1526)[name = tensor("big_5")]; tensor var_1528 = const()[name = tensor("op_1528"), val = tensor(0)]; tensor var_1521_promoted = cast(dtype = var_1521_promoted_dtype_0, x = var_1521)[name = tensor("cast_187")]; tensor zeros_5 = mul(x = var_1521_promoted, y = var_1528)[name = tensor("zeros_5")]; tensor var_1531_axes_0 = const()[name = tensor("op_1531_axes_0"), val = tensor([0])]; tensor var_1531 = expand_dims(axes = var_1531_axes_0, x = small_5)[name = tensor("op_1531")]; tensor var_1533_axes_0 = const()[name = tensor("op_1533_axes_0"), val = tensor([0])]; tensor var_1533 = expand_dims(axes = var_1533_axes_0, x = big_5)[name = tensor("op_1533")]; tensor var_1535 = const()[name = tensor("op_1535"), val = tensor(0)]; tensor x_padded_5_interleave_0 = const()[name = tensor("x_padded_5_interleave_0"), val = tensor(false)]; tensor x_padded_5 = concat(axis = var_1535, interleave = x_padded_5_interleave_0, values = (var_1531, x_9, var_1533))[name = tensor("x_padded_5")]; tensor var_1537 = const()[name = tensor("op_1537"), val = tensor(0)]; tensor logical_not_4 = const()[name = tensor("logical_not_4"), val = tensor(true)]; tensor i_5 = argsort(ascending = logical_not_4, axis = var_1537, x = x_padded_5)[name = tensor("i_5")]; tensor by_x_5 = gather_along_axis(axis = var_1537, indices = i_5, x = x_padded_5)[name = tensor("by_x_5")]; tensor var_1545_begin_0 = const()[name = tensor("op_1545_begin_0"), val = tensor([1, 0])]; tensor var_1545_end_0 = const()[name = tensor("op_1545_end_0"), val = tensor([1001, 50])]; tensor var_1545_end_mask_0 = const()[name = tensor("op_1545_end_mask_0"), val = tensor([false, true])]; tensor var_1545 = slice_by_index(begin = var_1545_begin_0, end = var_1545_end_0, end_mask = var_1545_end_mask_0, x = by_x_5)[name = tensor("op_1545")]; tensor var_1550_begin_0 = const()[name = tensor("op_1550_begin_0"), val = tensor([0, 0])]; tensor var_1550_end_0 = const()[name = tensor("op_1550_end_0"), val = tensor([1000, 50])]; tensor var_1550_end_mask_0 = const()[name = tensor("op_1550_end_mask_0"), val = tensor([false, true])]; tensor var_1550 = slice_by_index(begin = var_1550_begin_0, end = var_1550_end_0, end_mask = var_1550_end_mask_0, x = by_x_5)[name = tensor("op_1550")]; tensor var_1552 = sub(x = var_1545, y = var_1550)[name = tensor("op_1552")]; tensor var_1553_promoted = const()[name = tensor("op_1553_promoted"), val = tensor(0x0p+0)]; tensor var_1554 = greater(x = var_1552, y = var_1553_promoted)[name = tensor("op_1554")]; tensor var_1554_promoted_dtype_0 = const()[name = tensor("op_1554_promoted_dtype_0"), val = tensor("int32")]; tensor var_1558_axes_0 = const()[name = tensor("op_1558_axes_0"), val = tensor([0])]; tensor var_1558 = expand_dims(axes = var_1558_axes_0, x = zeros_5)[name = tensor("op_1558")]; tensor var_1562 = const()[name = tensor("op_1562"), val = tensor(0)]; tensor mask_17_interleave_0 = const()[name = tensor("mask_17_interleave_0"), val = tensor(false)]; tensor var_1554_promoted = cast(dtype = var_1554_promoted_dtype_0, x = var_1554)[name = tensor("cast_185")]; tensor mask_17 = concat(axis = var_1562, interleave = mask_17_interleave_0, values = (var_1558, var_1554_promoted, var_1558))[name = tensor("mask_17")]; tensor mask_17_promoted_dtype_0 = const()[name = tensor("mask_17_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_17_promoted = cast(dtype = mask_17_promoted_dtype_0, x = mask_17)[name = tensor("cast_184")]; tensor var_1564 = mul(x = by_x_5, y = mask_17_promoted)[name = tensor("op_1564")]; tensor var_1565 = const()[name = tensor("op_1565"), val = tensor(0)]; tensor logical_not_5 = const()[name = tensor("logical_not_5"), val = tensor(true)]; tensor var_1567 = argsort(ascending = logical_not_5, axis = var_1565, x = i_5)[name = tensor("op_1567")]; tensor var_1568 = const()[name = tensor("op_1568"), val = tensor(0)]; tensor unique_5 = gather_along_axis(axis = var_1568, indices = var_1567, x = var_1564)[name = tensor("unique_5")]; tensor unique_pairs_begin_0 = const()[name = tensor("unique_pairs_begin_0"), val = tensor([1, 0])]; tensor unique_pairs_end_0 = const()[name = tensor("unique_pairs_end_0"), val = tensor([1001, 50])]; tensor unique_pairs_end_mask_0 = const()[name = tensor("unique_pairs_end_mask_0"), val = tensor([false, true])]; tensor unique_pairs = slice_by_index(begin = unique_pairs_begin_0, end = unique_pairs_end_0, end_mask = unique_pairs_end_mask_0, x = unique_5)[name = tensor("unique_pairs")]; tensor var_1576 = const()[name = tensor("op_1576"), val = tensor(0x1p+0)]; tensor var_1577 = greater(x = unique_pairs, y = var_1576)[name = tensor("op_1577")]; tensor var_1577_promoted_dtype_0 = const()[name = tensor("op_1577_promoted_dtype_0"), val = tensor("fp32")]; tensor pair_coverage_axes_0 = const()[name = tensor("pair_coverage_axes_0"), val = tensor([0])]; tensor pair_coverage_keep_dims_0 = const()[name = tensor("pair_coverage_keep_dims_0"), val = tensor(false)]; tensor var_1577_promoted = cast(dtype = var_1577_promoted_dtype_0, x = var_1577)[name = tensor("cast_183")]; tensor pair_coverage = reduce_sum(axes = pair_coverage_axes_0, keep_dims = pair_coverage_keep_dims_0, x = var_1577_promoted)[name = tensor("pair_coverage")]; tensor var_1586 = const()[name = tensor("op_1586"), val = tensor(0x0p+0)]; tensor var_1587 = greater(x = x_3, y = var_1586)[name = tensor("op_1587")]; tensor cast_42_dtype_0 = const()[name = tensor("cast_42_dtype_0"), val = tensor("fp32")]; tensor optimal_pair_coverage_axes_0 = const()[name = tensor("optimal_pair_coverage_axes_0"), val = tensor([1])]; tensor optimal_pair_coverage_keep_dims_0 = const()[name = tensor("optimal_pair_coverage_keep_dims_0"), val = tensor(false)]; tensor cast_42 = cast(dtype = cast_42_dtype_0, x = var_1587)[name = tensor("cast_182")]; tensor optimal_pair_coverage = reduce_sum(axes = optimal_pair_coverage_axes_0, keep_dims = optimal_pair_coverage_keep_dims_0, x = cast_42)[name = tensor("optimal_pair_coverage")]; tensor coverage = sub(x = optimal_pair_coverage, y = pair_coverage)[name = tensor("coverage")]; tensor var_1597_perm_0 = const()[name = tensor("op_1597_perm_0"), val = tensor([1, 0])]; tensor var_1602_begin_0 = const()[name = tensor("op_1602_begin_0"), val = tensor([0, 0])]; tensor var_1602_end_0 = const()[name = tensor("op_1602_end_0"), val = tensor([1, 1000])]; tensor var_1602_end_mask_0 = const()[name = tensor("op_1602_end_mask_0"), val = tensor([false, true])]; tensor var_1597 = transpose(perm = var_1597_perm_0, x = tuple_alignment)[name = tensor("transpose_31")]; tensor var_1602 = slice_by_index(begin = var_1602_begin_0, end = var_1602_end_0, end_mask = var_1602_end_mask_0, x = var_1597)[name = tensor("op_1602")]; tensor alignments_3 = squeeze(x = var_1602)[name = tensor("alignments_3")]; tensor var_1609_begin_0 = const()[name = tensor("op_1609_begin_0"), val = tensor([0, 0])]; tensor var_1609_end_0 = const()[name = tensor("op_1609_end_0"), val = tensor([1, 1000])]; tensor var_1609_end_mask_0 = const()[name = tensor("op_1609_end_mask_0"), val = tensor([false, true])]; tensor var_1609_squeeze_mask_0 = const()[name = tensor("op_1609_squeeze_mask_0"), val = tensor([true, false])]; tensor var_1609 = slice_by_index(begin = var_1609_begin_0, end = var_1609_end_0, end_mask = var_1609_end_mask_0, squeeze_mask = var_1609_squeeze_mask_0, x = var_1597)[name = tensor("op_1609")]; tensor var_1610_promoted = const()[name = tensor("op_1610_promoted"), val = tensor(0x0p+0)]; tensor var_1611 = mul(x = var_1609, y = var_1610_promoted)[name = tensor("op_1611")]; tensor zero_slice_axes_0 = const()[name = tensor("zero_slice_axes_0"), val = tensor([0])]; tensor zero_slice = expand_dims(axes = zero_slice_axes_0, x = var_1611)[name = tensor("zero_slice")]; tensor var_1615 = const()[name = tensor("op_1615"), val = tensor(0)]; tensor context_feedback_21_interleave_0 = const()[name = tensor("context_feedback_21_interleave_0"), val = tensor(false)]; tensor context_feedback_21 = concat(axis = var_1615, interleave = context_feedback_21_interleave_0, values = (var_1597, zero_slice))[name = tensor("context_feedback_21")]; tensor var_1621_begin_0 = const()[name = tensor("op_1621_begin_0"), val = tensor([3, 0])]; tensor var_1621_end_0 = const()[name = tensor("op_1621_end_0"), val = tensor([9, 1000])]; tensor var_1621_end_mask_0 = const()[name = tensor("op_1621_end_mask_0"), val = tensor([false, true])]; tensor var_1621 = slice_by_index(begin = var_1621_begin_0, end = var_1621_end_0, end_mask = var_1621_end_mask_0, x = context_feedback_21)[name = tensor("op_1621")]; tensor context_feedback_23_perm_0 = const()[name = tensor("context_feedback_23_perm_0"), val = tensor([1, 0])]; tensor var_1625 = const()[name = tensor("op_1625"), val = tensor(-0x1.8p+0)]; tensor context_feedback_23 = transpose(perm = context_feedback_23_perm_0, x = var_1621)[name = tensor("transpose_30")]; tensor var_1626 = greater(x = context_feedback_23, y = var_1625)[name = tensor("op_1626")]; tensor var_1626_promoted_dtype_0 = const()[name = tensor("op_1626_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1630 = const()[name = tensor("op_1630"), val = tensor(-0x1p-1)]; tensor var_1631 = less(x = context_feedback_23, y = var_1630)[name = tensor("op_1631")]; tensor var_1631_promoted_dtype_0 = const()[name = tensor("op_1631_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1631_promoted = cast(dtype = var_1631_promoted_dtype_0, x = var_1631)[name = tensor("cast_180")]; tensor var_1626_promoted = cast(dtype = var_1626_promoted_dtype_0, x = var_1626)[name = tensor("cast_181")]; tensor var_1635 = mul(x = var_1626_promoted, y = var_1631_promoted)[name = tensor("op_1635")]; tensor var_1636_promoted = const()[name = tensor("op_1636_promoted"), val = tensor(0x1p+0)]; tensor not_padded_feedback = sub(x = var_1636_promoted, y = var_1635)[name = tensor("not_padded_feedback")]; tensor var_1639 = const()[name = tensor("op_1639"), val = tensor(-0x1.19999ap+0)]; tensor var_1640 = greater(x = context_feedback_23, y = var_1639)[name = tensor("op_1640")]; tensor var_1640_promoted_dtype_0 = const()[name = tensor("op_1640_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1644 = const()[name = tensor("op_1644"), val = tensor(-0x1.ccccccp-1)]; tensor var_1645 = less(x = context_feedback_23, y = var_1644)[name = tensor("op_1645")]; tensor var_1645_promoted_dtype_0 = const()[name = tensor("op_1645_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1645_promoted = cast(dtype = var_1645_promoted_dtype_0, x = var_1645)[name = tensor("cast_178")]; tensor var_1640_promoted = cast(dtype = var_1640_promoted_dtype_0, x = var_1640)[name = tensor("cast_179")]; tensor is_padding = mul(x = var_1640_promoted, y = var_1645_promoted)[name = tensor("is_padding")]; tensor var_1650 = const()[name = tensor("op_1650"), val = tensor(0x1p+0)]; tensor is_not_padding = sub(x = var_1650, y = is_padding)[name = tensor("is_not_padding")]; tensor var_1655 = mul(x = is_padding, y = var_625_promoted)[name = tensor("op_1655")]; tensor var_1656 = mul(x = is_not_padding, y = context_sigma)[name = tensor("op_1656")]; tensor padded_sigma = add(x = var_1655, y = var_1656)[name = tensor("padded_sigma")]; tensor var_1659 = mul(x = scatter_nd_2, y = is_not_padding)[name = tensor("op_1659")]; tensor var_1661 = sub(x = context_feedback_23, y = var_1659)[name = tensor("op_1661")]; tensor var_1662 = real_div(x = var_1661, y = padded_sigma)[name = tensor("op_1662")]; tensor var_1663 = mul(x = is_not_padding, y = var_1662)[name = tensor("op_1663")]; tensor var_1664 = mul(x = is_padding, y = context_feedback_23)[name = tensor("op_1664")]; tensor context_feedback = add(x = var_1663, y = var_1664)[name = tensor("context_feedback")]; tensor var_1667 = const()[name = tensor("op_1667"), val = tensor(-0x1p-1)]; tensor var_1668 = greater(x = alignments_3, y = var_1667)[name = tensor("op_1668")]; tensor var_1668_promoted_dtype_0 = const()[name = tensor("op_1668_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1673 = const()[name = tensor("op_1673"), val = tensor(0x1p-1)]; tensor var_1674 = sub(x = alignments_3, y = var_1673)[name = tensor("op_1674")]; tensor var_1668_promoted = cast(dtype = var_1668_promoted_dtype_0, x = var_1668)[name = tensor("cast_177")]; tensor var_1675 = mul(x = var_1668_promoted, y = var_1674)[name = tensor("op_1675")]; tensor var_1676_promoted = const()[name = tensor("op_1676_promoted"), val = tensor(0x1p+1)]; tensor alignments = mul(x = var_1675, y = var_1676_promoted)[name = tensor("alignments")]; tensor var_1693_perm_0 = const()[name = tensor("op_1693_perm_0"), val = tensor([1, 0])]; tensor time_context_feedback_begin_0 = const()[name = tensor("time_context_feedback_begin_0"), val = tensor([0, 0])]; tensor time_context_feedback_end_0 = const()[name = tensor("time_context_feedback_end_0"), val = tensor([1, 1000])]; tensor time_context_feedback_end_mask_0 = const()[name = tensor("time_context_feedback_end_mask_0"), val = tensor([false, true])]; tensor time_context_feedback_squeeze_mask_0 = const()[name = tensor("time_context_feedback_squeeze_mask_0"), val = tensor([true, false])]; tensor var_1693 = transpose(perm = var_1693_perm_0, x = context_feedback)[name = tensor("transpose_29")]; tensor time_context_feedback = slice_by_index(begin = time_context_feedback_begin_0, end = time_context_feedback_end_0, end_mask = time_context_feedback_end_mask_0, squeeze_mask = time_context_feedback_squeeze_mask_0, x = var_1693)[name = tensor("time_context_feedback")]; tensor var_1699_perm_0 = const()[name = tensor("op_1699_perm_0"), val = tensor([1, 0])]; tensor not_padded_time_begin_0 = const()[name = tensor("not_padded_time_begin_0"), val = tensor([0, 0])]; tensor not_padded_time_end_0 = const()[name = tensor("not_padded_time_end_0"), val = tensor([1, 1000])]; tensor not_padded_time_end_mask_0 = const()[name = tensor("not_padded_time_end_mask_0"), val = tensor([false, true])]; tensor not_padded_time_squeeze_mask_0 = const()[name = tensor("not_padded_time_squeeze_mask_0"), val = tensor([true, false])]; tensor var_1699 = transpose(perm = var_1699_perm_0, x = not_padded_feedback)[name = tensor("transpose_28")]; tensor not_padded_time = slice_by_index(begin = not_padded_time_begin_0, end = not_padded_time_end_0, end_mask = not_padded_time_end_mask_0, squeeze_mask = not_padded_time_squeeze_mask_0, x = var_1699)[name = tensor("not_padded_time")]; tensor var_1710_begin_0 = const()[name = tensor("op_1710_begin_0"), val = tensor([1, 0])]; tensor var_1710_end_0 = const()[name = tensor("op_1710_end_0"), val = tensor([4, 1000])]; tensor var_1710_end_mask_0 = const()[name = tensor("op_1710_end_mask_0"), val = tensor([false, true])]; tensor var_1710 = slice_by_index(begin = var_1710_begin_0, end = var_1710_end_0, end_mask = var_1710_end_mask_0, x = var_1693)[name = tensor("op_1710")]; tensor location_context_feedback_perm_0 = const()[name = tensor("location_context_feedback_perm_0"), val = tensor([1, 0])]; tensor var_1721_begin_0 = const()[name = tensor("op_1721_begin_0"), val = tensor([1, 0])]; tensor var_1721_end_0 = const()[name = tensor("op_1721_end_0"), val = tensor([4, 1000])]; tensor var_1721_end_mask_0 = const()[name = tensor("op_1721_end_mask_0"), val = tensor([false, true])]; tensor var_1721 = slice_by_index(begin = var_1721_begin_0, end = var_1721_end_0, end_mask = var_1721_end_mask_0, x = var_1699)[name = tensor("op_1721")]; tensor not_padded_location_perm_0 = const()[name = tensor("not_padded_location_perm_0"), val = tensor([1, 0])]; tensor var_1732_begin_0 = const()[name = tensor("op_1732_begin_0"), val = tensor([4, 0])]; tensor var_1732_end_0 = const()[name = tensor("op_1732_end_0"), val = tensor([6, 1000])]; tensor var_1732_end_mask_0 = const()[name = tensor("op_1732_end_mask_0"), val = tensor([true, true])]; tensor var_1732 = slice_by_index(begin = var_1732_begin_0, end = var_1732_end_0, end_mask = var_1732_end_mask_0, x = var_1693)[name = tensor("op_1732")]; tensor freq_context_feedback_perm_0 = const()[name = tensor("freq_context_feedback_perm_0"), val = tensor([1, 0])]; tensor var_1743_begin_0 = const()[name = tensor("op_1743_begin_0"), val = tensor([4, 0])]; tensor var_1743_end_0 = const()[name = tensor("op_1743_end_0"), val = tensor([6, 1000])]; tensor var_1743_end_mask_0 = const()[name = tensor("op_1743_end_mask_0"), val = tensor([true, true])]; tensor var_1743 = slice_by_index(begin = var_1743_begin_0, end = var_1743_end_0, end_mask = var_1743_end_mask_0, x = var_1699)[name = tensor("op_1743")]; tensor not_padded_freq_perm_0 = const()[name = tensor("not_padded_freq_perm_0"), val = tensor([1, 0])]; tensor var_1748 = sub(x = time_context_feedback, y = time_context_1)[name = tensor("op_1748")]; tensor var_1749 = abs(x = var_1748)[name = tensor("op_1749")]; tensor similarity_time = mul(x = var_1749, y = not_padded_time)[name = tensor("similarity_time")]; tensor freq_context_feedback = transpose(perm = freq_context_feedback_perm_0, x = var_1732)[name = tensor("transpose_25")]; tensor var_1752 = sub(x = freq_context_feedback, y = freq_context_1)[name = tensor("op_1752")]; tensor not_padded_freq = transpose(perm = not_padded_freq_perm_0, x = var_1743)[name = tensor("transpose_24")]; tensor input_5 = mul(x = var_1752, y = not_padded_freq)[name = tensor("input_5")]; tensor var_1756 = const()[name = tensor("op_1756"), val = tensor([1])]; tensor var_1757 = const()[name = tensor("op_1757"), val = tensor(false)]; tensor similarity_freq = reduce_l2_norm(axes = var_1756, keep_dims = var_1757, x = input_5)[name = tensor("similarity_freq")]; tensor location_context_feedback = transpose(perm = location_context_feedback_perm_0, x = var_1710)[name = tensor("transpose_27")]; tensor var_1761 = sub(x = location_context_feedback, y = location_context_1)[name = tensor("op_1761")]; tensor not_padded_location = transpose(perm = not_padded_location_perm_0, x = var_1721)[name = tensor("transpose_26")]; tensor input_7 = mul(x = var_1761, y = not_padded_location)[name = tensor("input_7")]; tensor var_1765 = const()[name = tensor("op_1765"), val = tensor([1])]; tensor var_1766 = const()[name = tensor("op_1766"), val = tensor(false)]; tensor similarity_location = reduce_l2_norm(axes = var_1765, keep_dims = var_1766, x = input_7)[name = tensor("similarity_location")]; tensor var_1769 = abs(x = alignments)[name = tensor("op_1769")]; tensor var_1770 = const()[name = tensor("op_1770"), val = tensor(0x1.0624dep-10)]; tensor var_1771 = greater(x = var_1769, y = var_1770)[name = tensor("op_1771")]; tensor var_1771_promoted_dtype_0 = const()[name = tensor("op_1771_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1771_promoted = cast(dtype = var_1771_promoted_dtype_0, x = var_1771)[name = tensor("cast_176")]; tensor var_1775 = mul(x = not_padded_time, y = var_1771_promoted)[name = tensor("op_1775")]; tensor n_time_axes_0 = const()[name = tensor("n_time_axes_0"), val = tensor([0])]; tensor n_time_keep_dims_0 = const()[name = tensor("n_time_keep_dims_0"), val = tensor(false)]; tensor n_time = reduce_sum(axes = n_time_axes_0, keep_dims = n_time_keep_dims_0, x = var_1775)[name = tensor("n_time")]; tensor var_1781 = const()[name = tensor("op_1781"), val = tensor(0x1p-1)]; tensor var_1782 = pow(x = n_time, y = var_1781)[name = tensor("op_1782")]; tensor var_1783_epsilon_0 = const()[name = tensor("op_1783_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_1783 = inverse(epsilon = var_1783_epsilon_0, x = var_1782)[name = tensor("op_1783")]; tensor var_1784 = const()[name = tensor("op_1784"), val = tensor(0x1.2635e6p+3)]; tensor bw_time = mul(x = var_1783, y = var_1784)[name = tensor("bw_time")]; tensor var_1790_axes_0 = const()[name = tensor("op_1790_axes_0"), val = tensor([1])]; tensor var_1790_keep_dims_0 = const()[name = tensor("op_1790_keep_dims_0"), val = tensor(false)]; tensor var_1790 = reduce_sum(axes = var_1790_axes_0, keep_dims = var_1790_keep_dims_0, x = not_padded_freq)[name = tensor("op_1790")]; tensor var_1791 = const()[name = tensor("op_1791"), val = tensor(0x0p+0)]; tensor var_1792 = greater(x = var_1790, y = var_1791)[name = tensor("op_1792")]; tensor var_1792_promoted_dtype_0 = const()[name = tensor("op_1792_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1792_promoted = cast(dtype = var_1792_promoted_dtype_0, x = var_1792)[name = tensor("cast_175")]; tensor var_1793 = mul(x = var_1792_promoted, y = var_1771_promoted)[name = tensor("op_1793")]; tensor n_freq_axes_0 = const()[name = tensor("n_freq_axes_0"), val = tensor([0])]; tensor n_freq_keep_dims_0 = const()[name = tensor("n_freq_keep_dims_0"), val = tensor(false)]; tensor n_freq = reduce_sum(axes = n_freq_axes_0, keep_dims = n_freq_keep_dims_0, x = var_1793)[name = tensor("n_freq")]; tensor var_1799 = const()[name = tensor("op_1799"), val = tensor(0x1p-1)]; tensor var_1800 = pow(x = n_freq, y = var_1799)[name = tensor("op_1800")]; tensor var_1801_epsilon_0 = const()[name = tensor("op_1801_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_1801 = inverse(epsilon = var_1801_epsilon_0, x = var_1800)[name = tensor("op_1801")]; tensor var_1802 = const()[name = tensor("op_1802"), val = tensor(0x1.f80ac8p+2)]; tensor bw_freq = mul(x = var_1801, y = var_1802)[name = tensor("bw_freq")]; tensor var_1808_axes_0 = const()[name = tensor("op_1808_axes_0"), val = tensor([1])]; tensor var_1808_keep_dims_0 = const()[name = tensor("op_1808_keep_dims_0"), val = tensor(false)]; tensor var_1808 = reduce_sum(axes = var_1808_axes_0, keep_dims = var_1808_keep_dims_0, x = not_padded_location)[name = tensor("op_1808")]; tensor var_1809 = const()[name = tensor("op_1809"), val = tensor(0x0p+0)]; tensor var_1810 = greater(x = var_1808, y = var_1809)[name = tensor("op_1810")]; tensor var_1810_promoted_dtype_0 = const()[name = tensor("op_1810_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1810_promoted = cast(dtype = var_1810_promoted_dtype_0, x = var_1810)[name = tensor("cast_174")]; tensor var_1811 = mul(x = var_1810_promoted, y = var_1771_promoted)[name = tensor("op_1811")]; tensor n_location_axes_0 = const()[name = tensor("n_location_axes_0"), val = tensor([0])]; tensor n_location_keep_dims_0 = const()[name = tensor("n_location_keep_dims_0"), val = tensor(false)]; tensor n_location = reduce_sum(axes = n_location_axes_0, keep_dims = n_location_keep_dims_0, x = var_1811)[name = tensor("n_location")]; tensor var_1817 = const()[name = tensor("op_1817"), val = tensor(0x1.e1583ep-2)]; tensor var_1818 = pow(x = n_location, y = var_1817)[name = tensor("op_1818")]; tensor var_1819_epsilon_0 = const()[name = tensor("op_1819_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_1819 = inverse(epsilon = var_1819_epsilon_0, x = var_1818)[name = tensor("op_1819")]; tensor var_1820 = const()[name = tensor("op_1820"), val = tensor(0x1.292be6p+4)]; tensor bw_location = mul(x = var_1819, y = var_1820)[name = tensor("bw_location")]; tensor var_1822_promoted = const()[name = tensor("op_1822_promoted"), val = tensor(0x0p+0)]; tensor var_1823 = greater(x = alignments, y = var_1822_promoted)[name = tensor("op_1823")]; tensor var_1823_promoted_dtype_0 = const()[name = tensor("op_1823_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1823_promoted = cast(dtype = var_1823_promoted_dtype_0, x = var_1823)[name = tensor("cast_173")]; tensor positive_alignments = mul(x = var_1823_promoted, y = alignments)[name = tensor("positive_alignments")]; tensor var_1825 = real_div(x = similarity_location, y = bw_location)[name = tensor("op_1825")]; tensor var_1826_promoted = const()[name = tensor("op_1826_promoted"), val = tensor(0x1p+1)]; tensor var_1827 = pow(x = var_1825, y = var_1826_promoted)[name = tensor("op_1827")]; tensor var_1828_promoted = const()[name = tensor("op_1828_promoted"), val = tensor(-0x1p+0)]; tensor var_1829 = mul(x = var_1827, y = var_1828_promoted)[name = tensor("op_1829")]; tensor location_score = exp(x = var_1829)[name = tensor("location_score")]; tensor var_1831 = real_div(x = similarity_time, y = bw_time)[name = tensor("op_1831")]; tensor var_1832_promoted = const()[name = tensor("op_1832_promoted"), val = tensor(0x1p+1)]; tensor var_1833 = pow(x = var_1831, y = var_1832_promoted)[name = tensor("op_1833")]; tensor var_1834_promoted = const()[name = tensor("op_1834_promoted"), val = tensor(-0x1p+0)]; tensor var_1835 = mul(x = var_1833, y = var_1834_promoted)[name = tensor("op_1835")]; tensor time_score = exp(x = var_1835)[name = tensor("time_score")]; tensor var_1837 = real_div(x = similarity_freq, y = bw_freq)[name = tensor("op_1837")]; tensor var_1838_promoted = const()[name = tensor("op_1838_promoted"), val = tensor(0x1p+1)]; tensor var_1839 = pow(x = var_1837, y = var_1838_promoted)[name = tensor("op_1839")]; tensor var_1840_promoted = const()[name = tensor("op_1840_promoted"), val = tensor(-0x1p+0)]; tensor var_1841 = mul(x = var_1839, y = var_1840_promoted)[name = tensor("op_1841")]; tensor freq_score = exp(x = var_1841)[name = tensor("freq_score")]; tensor var_1843 = mul(x = positive_alignments, y = time_score)[name = tensor("op_1843")]; tensor var_1844 = mul(x = var_1843, y = freq_score)[name = tensor("op_1844")]; tensor concentration_update = mul(x = var_1844, y = location_score)[name = tensor("concentration_update")]; tensor var_1847_axes_0 = const()[name = tensor("op_1847_axes_0"), val = tensor([0])]; tensor var_1847 = expand_dims(axes = var_1847_axes_0, x = concentration_update)[name = tensor("op_1847")]; tensor var_1849_axes_0 = const()[name = tensor("op_1849_axes_0"), val = tensor([0])]; tensor var_1849 = expand_dims(axes = var_1849_axes_0, x = var_1847)[name = tensor("op_1849")]; tensor pair_filtered_concentration = mul(x = var_1849, y = x_7)[name = tensor("pair_filtered_concentration")]; tensor per_candidate_concentration_1_axes_0 = const()[name = tensor("per_candidate_concentration_1_axes_0"), val = tensor([2])]; tensor per_candidate_concentration_1_keep_dims_0 = const()[name = tensor("per_candidate_concentration_1_keep_dims_0"), val = tensor(false)]; tensor per_candidate_concentration_1 = reduce_sum(axes = per_candidate_concentration_1_axes_0, keep_dims = per_candidate_concentration_1_keep_dims_0, x = pair_filtered_concentration)[name = tensor("per_candidate_concentration_1")]; tensor var_1856 = const()[name = tensor("op_1856"), val = tensor(0x1.e93e04p-2)]; tensor per_candidate_concentration = mul(x = per_candidate_concentration_1, y = var_1856)[name = tensor("per_candidate_concentration")]; tensor var_1858 = const()[name = tensor("op_1858"), val = tensor(-0x1.654744p+1)]; tensor var_1859 = mul(x = coverage, y = var_1858)[name = tensor("op_1859")]; tensor var_1860 = exp(x = var_1859)[name = tensor("op_1860")]; tensor discount_axes_0 = const()[name = tensor("discount_axes_0"), val = tensor([1])]; tensor discount = expand_dims(axes = discount_axes_0, x = var_1860)[name = tensor("discount")]; tensor var_1863 = mul(x = per_candidate_concentration, y = discount)[name = tensor("op_1863")]; tensor var_1865 = add(x = var_1154, y = var_1863)[name = tensor("op_1865")]; tensor var_1867 = add(x = var_1865, y = var_1366)[name = tensor("op_1867")]; tensor positive_counts_1 = add(x = var_1867, y = var_1285)[name = tensor("positive_counts_1")]; tensor var_1870 = const()[name = tensor("op_1870"), val = tensor(0x1p-1)]; tensor negative_counts = mul(x = var_1155, y = var_1870)[name = tensor("negative_counts")]; tensor var_1872_promoted = const()[name = tensor("op_1872_promoted"), val = tensor(0x0p+0)]; tensor var_1873 = greater(x = unique_candidates, y = var_1872_promoted)[name = tensor("op_1873")]; tensor var_1873_promoted_dtype_0 = const()[name = tensor("op_1873_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1873_promoted = cast(dtype = var_1873_promoted_dtype_0, x = var_1873)[name = tensor("cast_172")]; tensor var_1877 = mul(x = negative_counts, y = var_1873_promoted)[name = tensor("op_1877")]; tensor inversion_point_axes_0 = const()[name = tensor("inversion_point_axes_0"), val = tensor([0])]; tensor inversion_point_keep_dims_0 = const()[name = tensor("inversion_point_keep_dims_0"), val = tensor(false)]; tensor inversion_point = reduce_sum(axes = inversion_point_axes_0, keep_dims = inversion_point_keep_dims_0, x = var_1877)[name = tensor("inversion_point")]; tensor n_unique_axes_0 = const()[name = tensor("n_unique_axes_0"), val = tensor([0])]; tensor n_unique_keep_dims_0 = const()[name = tensor("n_unique_keep_dims_0"), val = tensor(false)]; tensor n_unique = reduce_sum(axes = n_unique_axes_0, keep_dims = n_unique_keep_dims_0, x = var_1873_promoted)[name = tensor("n_unique")]; tensor positive_others_1 = sub(x = inversion_point, y = negative_counts)[name = tensor("positive_others_1")]; tensor var_1891_promoted = const()[name = tensor("op_1891_promoted"), val = tensor(0x1p+0)]; tensor var_1892 = sub(x = n_unique, y = var_1891_promoted)[name = tensor("op_1892")]; tensor var_1894 = const()[name = tensor("op_1894"), val = tensor(0x1.0624dep-10)]; tensor var_1895 = add(x = var_1892, y = var_1894)[name = tensor("op_1895")]; tensor var_1896 = real_div(x = positive_others_1, y = var_1895)[name = tensor("op_1896")]; tensor var_1897 = const()[name = tensor("op_1897"), val = tensor(0x1.7f100cp+2)]; tensor positive_others = mul(x = var_1896, y = var_1897)[name = tensor("positive_others")]; tensor positive_counts = add(x = positive_counts_1, y = positive_others)[name = tensor("positive_counts")]; tensor var_1901_promoted = const()[name = tensor("op_1901_promoted"), val = tensor(-0x1p+0)]; tensor var_1902 = mul(x = negative_counts, y = var_1901_promoted)[name = tensor("op_1902")]; tensor var_1904 = const()[name = tensor("op_1904"), val = tensor(0x1.0624dep-10)]; tensor var_1905 = add(x = positive_counts, y = var_1904)[name = tensor("op_1905")]; tensor var_1906 = real_div(x = var_1902, y = var_1905)[name = tensor("op_1906")]; tensor var_1907 = exp(x = var_1906)[name = tensor("op_1907")]; tensor var_1908_promoted = const()[name = tensor("op_1908_promoted"), val = tensor(0x1p+0)]; tensor var_1910 = sub(x = var_1908_promoted, y = var_1907)[name = tensor("op_1910")]; tensor to_subtract_1 = mul(x = var_164_to_fp32, y = var_1910)[name = tensor("to_subtract_1")]; tensor to_subtract = mul(x = positive_counts, y = to_subtract_1)[name = tensor("to_subtract")]; tensor var_1914 = sub(x = positive_others, y = to_subtract)[name = tensor("op_1914")]; tensor var_1915 = mul(x = var_164_to_fp32, y = var_1914)[name = tensor("op_1915")]; tensor search_likelihood_begin_0 = const()[name = tensor("search_likelihood_begin_0"), val = tensor([7, 0, 0])]; tensor search_likelihood_end_0 = const()[name = tensor("search_likelihood_end_0"), val = tensor([8, 50, 15])]; tensor search_likelihood_end_mask_0 = const()[name = tensor("search_likelihood_end_mask_0"), val = tensor([false, true, true])]; tensor search_likelihood_squeeze_mask_0 = const()[name = tensor("search_likelihood_squeeze_mask_0"), val = tensor([true, false, false])]; tensor search_likelihood = slice_by_index(begin = search_likelihood_begin_0, end = search_likelihood_end_0, end_mask = search_likelihood_end_mask_0, squeeze_mask = search_likelihood_squeeze_mask_0, x = tuples)[name = tensor("search_likelihood")]; tensor var_1919 = const()[name = tensor("op_1919"), val = tensor(0x1p-1)]; tensor var_1920 = greater(x = unique_candidates, y = var_1919)[name = tensor("op_1920")]; tensor var_1920_promoted_dtype_0 = const()[name = tensor("op_1920_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1924 = const()[name = tensor("op_1924"), val = tensor(-0x1p-1)]; tensor var_1925 = greater(x = search_likelihood, y = var_1924)[name = tensor("op_1925")]; tensor var_1925_promoted_dtype_0 = const()[name = tensor("op_1925_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1920_promoted = cast(dtype = var_1920_promoted_dtype_0, x = var_1920)[name = tensor("cast_171")]; tensor var_1929 = mul(x = search_likelihood, y = var_1920_promoted)[name = tensor("op_1929")]; tensor max_like_keep_dims_0 = const()[name = tensor("max_like_keep_dims_0"), val = tensor(false)]; tensor max_like = reduce_max(keep_dims = max_like_keep_dims_0, x = var_1929)[name = tensor("max_like")]; tensor var_1931_promoted = const()[name = tensor("op_1931_promoted"), val = tensor(0x1p+0)]; tensor var_1933 = sub(x = var_1931_promoted, y = var_1920_promoted)[name = tensor("op_1933")]; tensor var_1934 = const()[name = tensor("op_1934"), val = tensor(0x1.f4p+9)]; tensor var_1935 = mul(x = var_1933, y = var_1934)[name = tensor("op_1935")]; tensor var_1937 = add(x = search_likelihood, y = var_1935)[name = tensor("op_1937")]; tensor min_like_keep_dims_0 = const()[name = tensor("min_like_keep_dims_0"), val = tensor(false)]; tensor min_like = reduce_min(keep_dims = min_like_keep_dims_0, x = var_1937)[name = tensor("min_like")]; tensor var_1940 = sub(x = search_likelihood, y = min_like)[name = tensor("op_1940")]; tensor var_1925_promoted = cast(dtype = var_1925_promoted_dtype_0, x = var_1925)[name = tensor("cast_170")]; tensor search_concentration = mul(x = var_1940, y = var_1925_promoted)[name = tensor("search_concentration")]; tensor var_1942 = mul(x = search_concentration, y = var_1920_promoted)[name = tensor("op_1942")]; tensor like_sum_keep_dims_0 = const()[name = tensor("like_sum_keep_dims_0"), val = tensor(false)]; tensor like_sum = reduce_sum(keep_dims = like_sum_keep_dims_0, x = var_1942)[name = tensor("like_sum")]; tensor var_1946 = sub(x = max_like, y = min_like)[name = tensor("op_1946")]; tensor var_1947 = const()[name = tensor("op_1947"), val = tensor(0x1.0624dep-10)]; tensor var_1948 = less(x = like_sum, y = var_1947)[name = tensor("op_1948")]; tensor var_1949 = const()[name = tensor("op_1949"), val = tensor(0x1.0624dep-10)]; tensor var_1948_promoted_dtype_0 = const()[name = tensor("op_1948_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1948_promoted = cast(dtype = var_1948_promoted_dtype_0, x = var_1948)[name = tensor("cast_169")]; tensor var_1950 = mul(x = var_1948_promoted, y = var_1949)[name = tensor("op_1950")]; tensor var_1952 = add(x = like_sum, y = var_1950)[name = tensor("op_1952")]; tensor var_1953 = real_div(x = var_1946, y = var_1952)[name = tensor("op_1953")]; tensor var_1954 = mul(x = search_concentration, y = var_1953)[name = tensor("op_1954")]; tensor var_1957 = const()[name = tensor("op_1957"), val = tensor(0x1.666666p-1)]; tensor var_1958 = mul(x = var_1954, y = var_1957)[name = tensor("op_1958")]; tensor var_1960 = add(x = var_1366, y = var_1958)[name = tensor("op_1960")]; tensor var_1962 = add(x = positive_counts_1, y = var_1915)[name = tensor("op_1962")]; tensor var_1964 = add(x = var_1962, y = var_1958)[name = tensor("op_1964")]; tensor var_1966 = const()[name = tensor("op_1966"), val = tensor(0x1.19be8p-1)]; tensor var_1967 = add(x = var_1964, y = var_1966)[name = tensor("op_1967")]; tensor var_1969 = add(x = var_1154, y = var_1960)[name = tensor("op_1969")]; tensor var_1971 = add(x = var_1969, y = var_1958)[name = tensor("op_1971")]; tensor var_1973 = const()[name = tensor("op_1973"), val = tensor(0x1.19be8p-1)]; tensor var_1974 = add(x = var_1971, y = var_1973)[name = tensor("op_1974")]; tensor var_1975 = mul(x = var_1967, y = var_164_to_fp32)[name = tensor("op_1975")]; tensor concentration = mul(x = var_1974, y = var_164_to_fp32)[name = tensor("concentration")]; tensor var_1982 = mul(x = var_1873_promoted, y = concentration)[name = tensor("op_1982")]; tensor column_normalization_axes_0 = const()[name = tensor("column_normalization_axes_0"), val = tensor([0])]; tensor column_normalization_keep_dims_0 = const()[name = tensor("column_normalization_keep_dims_0"), val = tensor(false)]; tensor column_normalization = reduce_sum(axes = column_normalization_axes_0, keep_dims = column_normalization_keep_dims_0, x = var_1982)[name = tensor("column_normalization")]; tensor var_1989 = const()[name = tensor("op_1989"), val = tensor(0x1.0624dep-10)]; tensor var_1990 = add(x = column_normalization, y = var_1989)[name = tensor("op_1990")]; tensor dirichlet_mu_1 = real_div(x = concentration, y = var_1990)[name = tensor("dirichlet_mu_1")]; tensor var_1995 = real_div(x = var_1975, y = var_1990)[name = tensor("op_1995")]; tensor var_1996 = const()[name = tensor("op_1996"), val = tensor(0)]; tensor var_1997 = equal(x = var_164_promoted, y = var_1996)[name = tensor("op_1997")]; tensor var_1997_promoted_dtype_0 = const()[name = tensor("op_1997_promoted_dtype_0"), val = tensor("fp32")]; tensor var_1997_promoted = cast(dtype = var_1997_promoted_dtype_0, x = var_1997)[name = tensor("cast_168")]; tensor var_1999 = add(x = var_1995, y = var_1997_promoted)[name = tensor("op_1999")]; tensor var_2000_epsilon_0 = const()[name = tensor("op_2000_epsilon_0"), val = tensor(0x1p-149)]; tensor var_2000 = log(epsilon = var_2000_epsilon_0, x = var_1999)[name = tensor("op_2000")]; tensor var_2005_axes_0 = const()[name = tensor("op_2005_axes_0"), val = tensor([1])]; tensor var_2005_keep_dims_0 = const()[name = tensor("op_2005_keep_dims_0"), val = tensor(false)]; tensor var_2005 = reduce_sum(axes = var_2005_axes_0, keep_dims = var_2005_keep_dims_0, x = var_2000)[name = tensor("op_2005")]; tensor var_2006 = exp(x = var_2005)[name = tensor("op_2006")]; tensor already_prompted = abs(x = alreadyPrompted)[name = tensor("already_prompted")]; tensor is_resolved = abs(x = isResolved)[name = tensor("is_resolved")]; tensor var_174_promoted = cast(dtype = var_174_promoted_dtype_0, x = var_174)[name = tensor("cast_252")]; tensor var_2009 = mul(x = var_2006, y = var_174_promoted)[name = tensor("op_2009")]; tensor var_2011_keep_dims_0 = const()[name = tensor("op_2011_keep_dims_0"), val = tensor(false)]; tensor var_2011 = reduce_sum(keep_dims = var_2011_keep_dims_0, x = var_2009)[name = tensor("op_2011")]; tensor posterior = real_div(x = var_2009, y = var_2011)[name = tensor("posterior")]; tensor var_2030_axes_0 = const()[name = tensor("op_2030_axes_0"), val = tensor([1])]; tensor var_2030_keep_dims_0 = const()[name = tensor("op_2030_keep_dims_0"), val = tensor(false)]; tensor var_2030 = reduce_sum(axes = var_2030_axes_0, keep_dims = var_2030_keep_dims_0, x = x_3)[name = tensor("op_2030")]; tensor var_2032_axes_0 = const()[name = tensor("op_2032_axes_0"), val = tensor([-1])]; tensor var_2032 = expand_dims(axes = var_2032_axes_0, x = var_2030)[name = tensor("op_2032")]; tensor var_2033 = const()[name = tensor("op_2033"), val = tensor(-0x1.86ap+16)]; tensor var_2034 = greater(x = var_2032, y = var_2033)[name = tensor("op_2034")]; tensor var_2034_promoted_dtype_0 = const()[name = tensor("op_2034_promoted_dtype_0"), val = tensor("fp32")]; tensor transpose_1 = const()[name = tensor("transpose_1"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(24640)))]; tensor ones_with_shape_bias_0 = const()[name = tensor("ones_with_shape_bias_0"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(24768)))]; tensor var_2034_promoted = cast(dtype = var_2034_promoted_dtype_0, x = var_2034)[name = tensor("cast_167")]; tensor ones_with_shape = linear(bias = ones_with_shape_bias_0, weight = transpose_1, x = var_2034_promoted)[name = tensor("ones_with_shape")]; tensor var_2039 = const()[name = tensor("op_2039"), val = tensor(1)]; tensor var_2041_exclusive_0 = const()[name = tensor("op_2041_exclusive_0"), val = tensor(false)]; tensor var_2041_reverse_0 = const()[name = tensor("op_2041_reverse_0"), val = tensor(false)]; tensor var_2041 = cumsum(axis = var_2039, exclusive = var_2041_exclusive_0, reverse = var_2041_reverse_0, x = ones_with_shape)[name = tensor("op_2041")]; tensor var_2043_promoted = const()[name = tensor("op_2043_promoted"), val = tensor(0x1p+0)]; tensor initial_index = sub(x = var_2041, y = var_2043_promoted)[name = tensor("initial_index")]; tensor var_2045 = const()[name = tensor("op_2045"), val = tensor(-0x1p-1)]; tensor var_2046 = greater(x = candidate_risk, y = var_2045)[name = tensor("op_2046")]; tensor var_2046_promoted_dtype_0 = const()[name = tensor("op_2046_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2046_promoted = cast(dtype = var_2046_promoted_dtype_0, x = var_2046)[name = tensor("cast_166")]; tensor var_2050 = mul(x = var_2046_promoted, y = candidate_risk)[name = tensor("op_2050")]; tensor var_2051 = const()[name = tensor("op_2051"), val = tensor(0x1p+0)]; tensor var_2053 = sub(x = var_2051, y = var_2046_promoted)[name = tensor("op_2053")]; tensor var_2054_promoted = const()[name = tensor("op_2054_promoted"), val = tensor(0x1.4p+3)]; tensor var_2055 = mul(x = var_2053, y = var_2054_promoted)[name = tensor("op_2055")]; tensor candidate_risk_1 = add(x = var_2050, y = var_2055)[name = tensor("candidate_risk")]; tensor var_2058 = const()[name = tensor("op_2058"), val = tensor(-0x1p-1)]; tensor var_2059 = greater(x = riskLevel, y = var_2058)[name = tensor("op_2059")]; tensor var_2059_promoted_dtype_0 = const()[name = tensor("op_2059_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2059_promoted = cast(dtype = var_2059_promoted_dtype_0, x = var_2059)[name = tensor("cast_165")]; tensor var_2063 = mul(x = var_2059_promoted, y = riskLevel)[name = tensor("op_2063")]; tensor var_2064 = const()[name = tensor("op_2064"), val = tensor(0x1p+0)]; tensor var_2066 = sub(x = var_2064, y = var_2059_promoted)[name = tensor("op_2066")]; tensor var_2067_promoted = const()[name = tensor("op_2067_promoted"), val = tensor(0x1.4p+3)]; tensor var_2068 = mul(x = var_2066, y = var_2067_promoted)[name = tensor("op_2068")]; tensor risk = add(x = var_2063, y = var_2068)[name = tensor("risk")]; tensor reduce_min_1_axes_0 = const()[name = tensor("reduce_min_1_axes_0"), val = tensor([1])]; tensor reduce_min_1_keep_dims_0 = const()[name = tensor("reduce_min_1_keep_dims_0"), val = tensor(false)]; tensor reduce_min_1 = reduce_min(axes = reduce_min_1_axes_0, keep_dims = reduce_min_1_keep_dims_0, x = candidate_risk_1)[name = tensor("reduce_min_1")]; tensor var_2077 = const()[name = tensor("op_2077"), val = tensor([-1, 1])]; tensor input = reshape(shape = var_2077, x = reduce_min_1)[name = tensor("input")]; tensor const_2 = const()[name = tensor("const_2"), val = tensor(0x1p+0)]; tensor padded_highest_candidate_risk_pad_0 = const()[name = tensor("padded_highest_candidate_risk_pad_0"), val = tensor([0, 0, 0, 1])]; tensor padded_highest_candidate_risk_mode_0 = const()[name = tensor("padded_highest_candidate_risk_mode_0"), val = tensor("constant")]; tensor padded_highest_candidate_risk = pad(constant_val = const_2, mode = padded_highest_candidate_risk_mode_0, pad = padded_highest_candidate_risk_pad_0, x = input)[name = tensor("padded_highest_candidate_risk")]; tensor adjusted_risk = minimum(x = padded_highest_candidate_risk, y = risk)[name = tensor("adjusted_risk")]; tensor transpose_2 = const()[name = tensor("transpose_2"), val = tensor([[0x1p+1, 0x1p+0]])]; tensor lookup_index_bias_0 = const()[name = tensor("lookup_index_bias_0"), val = tensor([0x0p+0])]; tensor lookup_index = linear(bias = lookup_index_bias_0, weight = transpose_2, x = adjusted_risk)[name = tensor("lookup_index")]; tensor fill_1 = const()[name = tensor("fill_1"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(24896)))]; tensor var_2107_axes_0 = const()[name = tensor("op_2107_axes_0"), val = tensor([-1])]; tensor var_2107 = squeeze(axes = var_2107_axes_0, x = lookup_index)[name = tensor("op_2107")]; tensor risk_position = mul(x = fill_1, y = var_2107)[name = tensor("risk_position")]; tensor var_2109_perm_0 = const()[name = tensor("op_2109_perm_0"), val = tensor([1, 0])]; tensor var_2109 = transpose(perm = var_2109_perm_0, x = risk_position)[name = tensor("transpose_23")]; tensor var_2111 = sub(x = initial_index, y = var_2109)[name = tensor("op_2111")]; tensor var_2112 = const()[name = tensor("op_2112"), val = tensor(0x0p+0)]; tensor var_2113 = equal(x = var_2111, y = var_2112)[name = tensor("op_2113")]; tensor var_2113_promoted_dtype_0 = const()[name = tensor("op_2113_promoted_dtype_0"), val = tensor("fp32")]; tensor transpose_3 = const()[name = tensor("transpose_3"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(27392)))]; tensor risk_vector_bias_0 = const()[name = tensor("risk_vector_bias_0"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(28736)))]; tensor var_2113_promoted = cast(dtype = var_2113_promoted_dtype_0, x = var_2113)[name = tensor("cast_164")]; tensor risk_vector = linear(bias = risk_vector_bias_0, weight = transpose_3, x = var_2113_promoted)[name = tensor("risk_vector")]; tensor transpose_4 = const()[name = tensor("transpose_4"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(28928)))]; tensor var_2124_bias_0 = const()[name = tensor("op_2124_bias_0"), val = tensor([0x0p+0])]; tensor var_2124 = linear(bias = var_2124_bias_0, weight = transpose_4, x = risk_vector)[name = tensor("op_2124")]; tensor de_happy_1_axes_0 = const()[name = tensor("de_happy_1_axes_0"), val = tensor([-1])]; tensor de_happy_1 = squeeze(axes = de_happy_1_axes_0, x = var_2124)[name = tensor("de_happy_1")]; tensor transpose_5 = const()[name = tensor("transpose_5"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(29120)))]; tensor var_2132_bias_0 = const()[name = tensor("op_2132_bias_0"), val = tensor([0x0p+0])]; tensor var_2132 = linear(bias = var_2132_bias_0, weight = transpose_5, x = risk_vector)[name = tensor("op_2132")]; tensor de_sad_axes_0 = const()[name = tensor("de_sad_axes_0"), val = tensor([-1])]; tensor de_sad = squeeze(axes = de_sad_axes_0, x = var_2132)[name = tensor("de_sad")]; tensor transpose_6 = const()[name = tensor("transpose_6"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(29312)))]; tensor var_2140_bias_0 = const()[name = tensor("op_2140_bias_0"), val = tensor([0x0p+0])]; tensor var_2140 = linear(bias = var_2140_bias_0, weight = transpose_6, x = risk_vector)[name = tensor("op_2140")]; tensor action_confirm_happy_axes_0 = const()[name = tensor("action_confirm_happy_axes_0"), val = tensor([-1])]; tensor action_confirm_happy = squeeze(axes = action_confirm_happy_axes_0, x = var_2140)[name = tensor("action_confirm_happy")]; tensor transpose_7 = const()[name = tensor("transpose_7"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(29504)))]; tensor var_2148_bias_0 = const()[name = tensor("op_2148_bias_0"), val = tensor([0x0p+0])]; tensor var_2148 = linear(bias = var_2148_bias_0, weight = transpose_7, x = risk_vector)[name = tensor("op_2148")]; tensor action_confirm_sad_axes_0 = const()[name = tensor("action_confirm_sad_axes_0"), val = tensor([-1])]; tensor action_confirm_sad = squeeze(axes = action_confirm_sad_axes_0, x = var_2148)[name = tensor("action_confirm_sad")]; tensor transpose_8 = const()[name = tensor("transpose_8"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(29696)))]; tensor var_2156_bias_0 = const()[name = tensor("op_2156_bias_0"), val = tensor([0x0p+0])]; tensor var_2156 = linear(bias = var_2156_bias_0, weight = transpose_8, x = risk_vector)[name = tensor("op_2156")]; tensor param_confirm_happy_axes_0 = const()[name = tensor("param_confirm_happy_axes_0"), val = tensor([-1])]; tensor param_confirm_happy = squeeze(axes = param_confirm_happy_axes_0, x = var_2156)[name = tensor("param_confirm_happy")]; tensor transpose_9 = const()[name = tensor("transpose_9"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(29888)))]; tensor var_2164_bias_0 = const()[name = tensor("op_2164_bias_0"), val = tensor([0x0p+0])]; tensor var_2164 = linear(bias = var_2164_bias_0, weight = transpose_9, x = risk_vector)[name = tensor("op_2164")]; tensor param_confirm_sad_axes_0 = const()[name = tensor("param_confirm_sad_axes_0"), val = tensor([-1])]; tensor param_confirm_sad = squeeze(axes = param_confirm_sad_axes_0, x = var_2164)[name = tensor("param_confirm_sad")]; tensor var_2167 = const()[name = tensor("op_2167"), val = tensor(-0x1.e848p+19)]; tensor var_2168 = greater(x = x_3, y = var_2167)[name = tensor("op_2168")]; tensor var_2169 = const()[name = tensor("op_2169"), val = tensor(0x1.8p+1)]; tensor var_2168_promoted_dtype_0 = const()[name = tensor("op_2168_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2168_promoted = cast(dtype = var_2168_promoted_dtype_0, x = var_2168)[name = tensor("cast_163")]; tensor threes = mul(x = var_2168_promoted, y = var_2169)[name = tensor("threes")]; tensor var_2171_keep_dims_0 = const()[name = tensor("op_2171_keep_dims_0"), val = tensor(false)]; tensor var_2171 = reduce_max(keep_dims = var_2171_keep_dims_0, x = posterior)[name = tensor("op_2171")]; tensor var_2173 = sub(x = posterior, y = var_2171)[name = tensor("op_2173")]; tensor var_2174 = abs(x = var_2173)[name = tensor("op_2174")]; tensor var_2175 = const()[name = tensor("op_2175"), val = tensor(0x1.0624dep-10)]; tensor var_2176 = less(x = var_2174, y = var_2175)[name = tensor("op_2176")]; tensor var_2178_axes_0 = const()[name = tensor("op_2178_axes_0"), val = tensor([-1])]; tensor var_2178 = expand_dims(axes = var_2178_axes_0, x = var_2176)[name = tensor("op_2178")]; tensor var_2178_promoted_dtype_0 = const()[name = tensor("op_2178_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2178_promoted = cast(dtype = var_2178_promoted_dtype_0, x = var_2178)[name = tensor("cast_162")]; tensor de = mul(x = threes, y = var_2178_promoted)[name = tensor("de")]; tensor var_2180 = const()[name = tensor("op_2180"), val = tensor(0x1.a36e2ep-14)]; tensor de_happy = add(x = de_happy_1, y = var_2180)[name = tensor("de_happy")]; tensor var_2187_axes_0 = const()[name = tensor("op_2187_axes_0"), val = tensor([0])]; tensor var_2187_keep_dims_0 = const()[name = tensor("op_2187_keep_dims_0"), val = tensor(false)]; tensor var_2187 = reduce_sum(axes = var_2187_axes_0, keep_dims = var_2187_keep_dims_0, x = de)[name = tensor("op_2187")]; tensor reduce_max_4_axes_0 = const()[name = tensor("reduce_max_4_axes_0"), val = tensor([0])]; tensor reduce_max_4_keep_dims_0 = const()[name = tensor("reduce_max_4_keep_dims_0"), val = tensor(false)]; tensor reduce_max_4 = reduce_max(axes = reduce_max_4_axes_0, keep_dims = reduce_max_4_keep_dims_0, x = var_2187)[name = tensor("reduce_max_4")]; tensor var_2192 = const()[name = tensor("op_2192"), val = tensor(0x1.8p+1)]; tensor var_2193 = less_equal(x = reduce_max_4, y = var_2192)[name = tensor("op_2193")]; tensor var_2193_promoted_dtype_0 = const()[name = tensor("op_2193_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2196 = mul(x = de_happy, y = posterior)[name = tensor("op_2196")]; tensor var_2197_promoted = const()[name = tensor("op_2197_promoted"), val = tensor(0x1p+0)]; tensor var_2199 = sub(x = var_2197_promoted, y = posterior)[name = tensor("op_2199")]; tensor var_2200 = mul(x = de_sad, y = var_2199)[name = tensor("op_2200")]; tensor var_2202 = add(x = var_2196, y = var_2200)[name = tensor("op_2202")]; tensor var_2193_promoted = cast(dtype = var_2193_promoted_dtype_0, x = var_2193)[name = tensor("cast_161")]; tensor de_ev_tuple_1 = mul(x = var_2193_promoted, y = var_2202)[name = tensor("de_ev_tuple_1")]; tensor de_ev_tuple = mul(x = de_ev_tuple_1, y = var_174_promoted)[name = tensor("de_ev_tuple")]; tensor de_ev_keep_dims_0 = const()[name = tensor("de_ev_keep_dims_0"), val = tensor(false)]; tensor de_ev = reduce_max(keep_dims = de_ev_keep_dims_0, x = de_ev_tuple)[name = tensor("de_ev")]; tensor var_2208_keep_dims_0 = const()[name = tensor("op_2208_keep_dims_0"), val = tensor(false)]; tensor var_2208_axis_0 = const()[name = tensor("op_2208_axis_0"), val = tensor(-1)]; tensor var_2208 = reduce_argmax(axis = var_2208_axis_0, keep_dims = var_2208_keep_dims_0, x = de_ev_tuple)[name = tensor("op_2208")]; tensor var_2210 = const()[name = tensor("op_2210"), val = tensor(0x1p+0)]; tensor var_2208_promoted_dtype_0 = const()[name = tensor("op_2208_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2208_promoted = cast(dtype = var_2208_promoted_dtype_0, x = var_2208)[name = tensor("cast_160")]; tensor de_index = add(x = var_2208_promoted, y = var_2210)[name = tensor("de_index")]; tensor de_id_promoted = const()[name = tensor("de_id_promoted"), val = tensor([0x1p+0, 0x1p+0])]; tensor reshape_31 = const()[name = tensor("reshape_31"), val = tensor([1])]; tensor reshape_32_shape_0 = const()[name = tensor("reshape_32_shape_0"), val = tensor([-1])]; tensor reshape_32 = reshape(shape = reshape_32_shape_0, x = de_index)[name = tensor("reshape_32")]; tensor scatter_6_mode_0 = const()[name = tensor("scatter_6_mode_0"), val = tensor("update")]; tensor scatter_6_axis_0 = const()[name = tensor("scatter_6_axis_0"), val = tensor(0)]; tensor scatter_6 = scatter(axis = scatter_6_axis_0, data = de_id_promoted, indices = reshape_31, mode = scatter_6_mode_0, updates = reshape_32)[name = tensor("scatter_6")]; tensor var_2220 = const()[name = tensor("op_2220"), val = tensor(0x1p+1)]; tensor twos_1 = mul(x = var_2168_promoted, y = var_2220)[name = tensor("twos_1")]; tensor var_2223_promoted = const()[name = tensor("op_2223_promoted"), val = tensor(0x1p+1)]; tensor var_2224 = sub(x = component, y = var_2223_promoted)[name = tensor("op_2224")]; tensor var_2225 = abs(x = var_2224)[name = tensor("op_2225")]; tensor var_2226 = const()[name = tensor("op_2226"), val = tensor(0x1.0624dep-10)]; tensor var_2227 = less(x = var_2225, y = var_2226)[name = tensor("op_2227")]; tensor var_2227_promoted_dtype_0 = const()[name = tensor("op_2227_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2238_promoted = const()[name = tensor("op_2238_promoted"), val = tensor(0x1p+0)]; tensor var_2239 = sub(x = var_80, y = var_2238_promoted)[name = tensor("op_2239")]; tensor var_2240 = abs(x = var_2239)[name = tensor("op_2240")]; tensor var_2241 = const()[name = tensor("op_2241"), val = tensor(0x1.0624dep-10)]; tensor var_2242 = less(x = var_2240, y = var_2241)[name = tensor("op_2242")]; tensor var_2242_promoted_dtype_0 = const()[name = tensor("op_2242_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2246 = const()[name = tensor("op_2246"), val = tensor(0x1.0624dep-10)]; tensor var_2247 = less(x = concentration, y = var_2246)[name = tensor("op_2247")]; tensor var_2248 = const()[name = tensor("op_2248"), val = tensor(10)]; tensor var_2247_promoted_dtype_0 = const()[name = tensor("op_2247_promoted_dtype_0"), val = tensor("int32")]; tensor var_2247_promoted = cast(dtype = var_2247_promoted_dtype_0, x = var_2247)[name = tensor("cast_157")]; tensor dirichlet_zeros = mul(x = var_2247_promoted, y = var_2248)[name = tensor("dirichlet_zeros")]; tensor dirichlet_zeros_promoted_dtype_0 = const()[name = tensor("dirichlet_zeros_promoted_dtype_0"), val = tensor("fp32")]; tensor dirichlet_zeros_promoted = cast(dtype = dirichlet_zeros_promoted_dtype_0, x = dirichlet_zeros)[name = tensor("cast_156")]; tensor var_2251 = add(x = concentration, y = dirichlet_zeros_promoted)[name = tensor("op_2251")]; tensor var_2252_keep_dims_0 = const()[name = tensor("op_2252_keep_dims_0"), val = tensor(false)]; tensor var_2252 = reduce_min(keep_dims = var_2252_keep_dims_0, x = var_2251)[name = tensor("op_2252")]; tensor var_2253 = const()[name = tensor("op_2253"), val = tensor(0x1.5851ecp+1)]; tensor insufficient_confidence = less(x = var_2252, y = var_2253)[name = tensor("insufficient_confidence")]; tensor var_2242_promoted = cast(dtype = var_2242_promoted_dtype_0, x = var_2242)[name = tensor("cast_158")]; tensor var_2227_promoted = cast(dtype = var_2227_promoted_dtype_0, x = var_2227)[name = tensor("cast_159")]; tensor var_2255 = mul(x = var_2227_promoted, y = var_2242_promoted)[name = tensor("op_2255")]; tensor insufficient_confidence_promoted_dtype_0 = const()[name = tensor("insufficient_confidence_promoted_dtype_0"), val = tensor("fp32")]; tensor insufficient_confidence_promoted = cast(dtype = insufficient_confidence_promoted_dtype_0, x = insufficient_confidence)[name = tensor("cast_155")]; tensor single_confirm_check = mul(x = var_2255, y = insufficient_confidence_promoted)[name = tensor("single_confirm_check")]; tensor var_2258 = add(x = forced_1, y = single_confirm_check)[name = tensor("op_2258")]; tensor var_2259_promoted = const()[name = tensor("op_2259_promoted"), val = tensor(0x0p+0)]; tensor var_2260 = greater(x = var_2258, y = var_2259_promoted)[name = tensor("op_2260")]; tensor var_2260_promoted_dtype_0 = const()[name = tensor("op_2260_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2264 = abs(x = component)[name = tensor("op_2264")]; tensor var_2265 = const()[name = tensor("op_2265"), val = tensor(0x1.0624dep-10)]; tensor var_2266 = less(x = var_2264, y = var_2265)[name = tensor("op_2266")]; tensor var_2266_promoted_dtype_0 = const()[name = tensor("op_2266_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2266_promoted = cast(dtype = var_2266_promoted_dtype_0, x = var_2266)[name = tensor("cast_153")]; tensor var_2270 = mul(x = var_2266_promoted, y = scatter_6)[name = tensor("op_2270")]; tensor var_2271 = const()[name = tensor("op_2271"), val = tensor(0x1p+0)]; tensor var_2273 = sub(x = var_2271, y = var_2266_promoted)[name = tensor("op_2273")]; tensor var_2274_promoted = const()[name = tensor("op_2274_promoted"), val = tensor([0x1p+1, 0x0p+0])]; tensor var_2275 = mul(x = var_2273, y = var_2274_promoted)[name = tensor("op_2275")]; tensor ac_id = add(x = var_2270, y = var_2275)[name = tensor("ac_id")]; tensor ac = mul(x = twos_1, y = var_2178_promoted)[name = tensor("ac")]; tensor var_2291_axes_0 = const()[name = tensor("op_2291_axes_0"), val = tensor([0])]; tensor var_2291_keep_dims_0 = const()[name = tensor("op_2291_keep_dims_0"), val = tensor(false)]; tensor var_2291 = reduce_sum(axes = var_2291_axes_0, keep_dims = var_2291_keep_dims_0, x = ac)[name = tensor("op_2291")]; tensor reduce_max_5_axes_0 = const()[name = tensor("reduce_max_5_axes_0"), val = tensor([0])]; tensor reduce_max_5_keep_dims_0 = const()[name = tensor("reduce_max_5_keep_dims_0"), val = tensor(false)]; tensor reduce_max_5 = reduce_max(axes = reduce_max_5_axes_0, keep_dims = reduce_max_5_keep_dims_0, x = var_2291)[name = tensor("reduce_max_5")]; tensor var_2296 = const()[name = tensor("op_2296"), val = tensor(0x1p+1)]; tensor var_2297 = less_equal(x = reduce_max_5, y = var_2296)[name = tensor("op_2297")]; tensor var_2297_promoted_dtype_0 = const()[name = tensor("op_2297_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2300 = mul(x = action_confirm_happy, y = posterior)[name = tensor("op_2300")]; tensor var_2304 = mul(x = action_confirm_sad, y = var_2199)[name = tensor("op_2304")]; tensor var_2306 = add(x = var_2300, y = var_2304)[name = tensor("op_2306")]; tensor var_2297_promoted = cast(dtype = var_2297_promoted_dtype_0, x = var_2297)[name = tensor("cast_152")]; tensor ac_ev_tuple_1 = mul(x = var_2297_promoted, y = var_2306)[name = tensor("ac_ev_tuple_1")]; tensor ac_ev_tuple = mul(x = ac_ev_tuple_1, y = var_174_promoted)[name = tensor("ac_ev_tuple")]; tensor ac_ev_1_keep_dims_0 = const()[name = tensor("ac_ev_1_keep_dims_0"), val = tensor(false)]; tensor ac_ev_1 = reduce_max(keep_dims = ac_ev_1_keep_dims_0, x = ac_ev_tuple)[name = tensor("ac_ev_1")]; tensor var_2310_promoted = const()[name = tensor("op_2310_promoted"), val = tensor(0x0p+0)]; tensor var_2311 = greater(x = already_prompted, y = var_2310_promoted)[name = tensor("op_2311")]; tensor var_2311_promoted_dtype_0 = const()[name = tensor("op_2311_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2316_keep_dims_0 = const()[name = tensor("op_2316_keep_dims_0"), val = tensor(false)]; tensor var_2311_promoted = cast(dtype = var_2311_promoted_dtype_0, x = var_2311)[name = tensor("cast_151")]; tensor var_2316 = reduce_sum(keep_dims = var_2316_keep_dims_0, x = var_2311_promoted)[name = tensor("op_2316")]; tensor var_2317 = const()[name = tensor("op_2317"), val = tensor(-0x1.388p+13)]; tensor var_2318 = greater(x = already_prompted, y = var_2317)[name = tensor("op_2318")]; tensor var_2318_promoted_dtype_0 = const()[name = tensor("op_2318_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2323_keep_dims_0 = const()[name = tensor("op_2323_keep_dims_0"), val = tensor(false)]; tensor var_2318_promoted = cast(dtype = var_2318_promoted_dtype_0, x = var_2318)[name = tensor("cast_150")]; tensor var_2323 = reduce_sum(keep_dims = var_2323_keep_dims_0, x = var_2318_promoted)[name = tensor("op_2323")]; tensor var_2324 = equal(x = var_2316, y = var_2323)[name = tensor("op_2324")]; tensor var_2324_promoted_dtype_0 = const()[name = tensor("op_2324_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2328_promoted = const()[name = tensor("op_2328_promoted"), val = tensor(0x1p+0)]; tensor var_2324_promoted = cast(dtype = var_2324_promoted_dtype_0, x = var_2324)[name = tensor("cast_149")]; tensor var_2330 = sub(x = var_2328_promoted, y = var_2324_promoted)[name = tensor("op_2330")]; tensor ac_ev = mul(x = ac_ev_1, y = var_2330)[name = tensor("ac_ev")]; tensor var_2334_keep_dims_0 = const()[name = tensor("op_2334_keep_dims_0"), val = tensor(false)]; tensor var_2334_axis_0 = const()[name = tensor("op_2334_axis_0"), val = tensor(-1)]; tensor var_2334 = reduce_argmax(axis = var_2334_axis_0, keep_dims = var_2334_keep_dims_0, x = ac_ev_tuple)[name = tensor("op_2334")]; tensor var_2336 = const()[name = tensor("op_2336"), val = tensor(0x1p+0)]; tensor var_2334_promoted_dtype_0 = const()[name = tensor("op_2334_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2334_promoted = cast(dtype = var_2334_promoted_dtype_0, x = var_2334)[name = tensor("cast_148")]; tensor ac_index = add(x = var_2334_promoted, y = var_2336)[name = tensor("ac_index")]; tensor reshape_36 = const()[name = tensor("reshape_36"), val = tensor([1])]; tensor reshape_37_shape_0 = const()[name = tensor("reshape_37_shape_0"), val = tensor([-1])]; tensor reshape_37 = reshape(shape = reshape_37_shape_0, x = ac_index)[name = tensor("reshape_37")]; tensor scatter_7_mode_0 = const()[name = tensor("scatter_7_mode_0"), val = tensor("update")]; tensor scatter_7_axis_0 = const()[name = tensor("scatter_7_axis_0"), val = tensor(0)]; tensor scatter_7 = scatter(axis = scatter_7_axis_0, data = ac_id, indices = reshape_36, mode = scatter_7_mode_0, updates = reshape_37)[name = tensor("scatter_7")]; tensor var_2343 = const()[name = tensor("op_2343"), val = tensor(0x1.86ap+16)]; tensor var_2260_promoted = cast(dtype = var_2260_promoted_dtype_0, x = var_2260)[name = tensor("cast_154")]; tensor var_2344 = mul(x = var_2260_promoted, y = var_2343)[name = tensor("op_2344")]; tensor action_confirm_ev = add(x = ac_ev, y = var_2344)[name = tensor("action_confirm_ev")]; tensor var_2367_axes_0 = const()[name = tensor("op_2367_axes_0"), val = tensor([1])]; tensor var_2367_keep_dims_0 = const()[name = tensor("op_2367_keep_dims_0"), val = tensor(false)]; tensor var_2367 = reduce_sum(axes = var_2367_axes_0, keep_dims = var_2367_keep_dims_0, x = ac)[name = tensor("op_2367")]; tensor var_2368_promoted = const()[name = tensor("op_2368_promoted"), val = tensor(0x0p+0)]; tensor var_2369 = greater(x = var_2367, y = var_2368_promoted)[name = tensor("op_2369")]; tensor var_2369_promoted_dtype_0 = const()[name = tensor("op_2369_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2369_promoted = cast(dtype = var_2369_promoted_dtype_0, x = var_2369)[name = tensor("cast_147")]; tensor var_2373 = mul(x = var_2369_promoted, y = param_confirm_happy)[name = tensor("op_2373")]; tensor pc_happy_keep_dims_0 = const()[name = tensor("pc_happy_keep_dims_0"), val = tensor(false)]; tensor pc_happy = reduce_max(keep_dims = pc_happy_keep_dims_0, x = var_2373)[name = tensor("pc_happy")]; tensor var_2375 = mul(x = var_2369_promoted, y = param_confirm_sad)[name = tensor("op_2375")]; tensor pc_sad_keep_dims_0 = const()[name = tensor("pc_sad_keep_dims_0"), val = tensor(false)]; tensor pc_sad = reduce_max(keep_dims = pc_sad_keep_dims_0, x = var_2375)[name = tensor("pc_sad")]; tensor reshape_60_shape_0 = const()[name = tensor("reshape_60_shape_0"), val = tensor([50, 1])]; tensor reshape_60 = reshape(shape = reshape_60_shape_0, x = var_2369_promoted)[name = tensor("reshape_60")]; tensor var_2380 = mul(x = dirichlet_mu_1, y = reshape_60)[name = tensor("op_2380")]; tensor pc_mu_flat_axes_0 = const()[name = tensor("pc_mu_flat_axes_0"), val = tensor([0])]; tensor pc_mu_flat_keep_dims_0 = const()[name = tensor("pc_mu_flat_keep_dims_0"), val = tensor(false)]; tensor pc_mu_flat = reduce_sum(axes = pc_mu_flat_axes_0, keep_dims = pc_mu_flat_keep_dims_0, x = var_2380)[name = tensor("pc_mu_flat")]; tensor var_2402 = mul(x = pc_happy, y = pc_mu_flat)[name = tensor("op_2402")]; tensor var_2403_promoted = const()[name = tensor("op_2403_promoted"), val = tensor(0x1p+0)]; tensor var_2405 = sub(x = var_2403_promoted, y = pc_mu_flat)[name = tensor("op_2405")]; tensor var_2406 = mul(x = pc_sad, y = var_2405)[name = tensor("op_2406")]; tensor var_2408 = add(x = var_2402, y = var_2406)[name = tensor("op_2408")]; tensor pc_ev_1 = mul(x = var_2297_promoted, y = var_2408)[name = tensor("pc_ev_1")]; tensor var_2410 = const()[name = tensor("op_2410"), val = tensor(0x1.388p+13)]; tensor var_2411 = mul(x = already_prompted, y = var_2410)[name = tensor("op_2411")]; tensor var_2413 = add(x = pc_ev_1, y = var_2411)[name = tensor("op_2413")]; tensor var_2414 = const()[name = tensor("op_2414"), val = tensor(0x1.388p+13)]; tensor var_2415 = mul(x = is_resolved, y = var_2414)[name = tensor("op_2415")]; tensor var_2417 = add(x = var_2413, y = var_2415)[name = tensor("op_2417")]; tensor var_184_promoted = cast(dtype = var_184_promoted_dtype_0, x = var_184)[name = tensor("cast_250")]; tensor var_2418 = mul(x = var_184_promoted, y = var_2417)[name = tensor("op_2418")]; tensor var_2419 = const()[name = tensor("op_2419"), val = tensor(0x1p+0)]; tensor var_2421 = sub(x = var_2419, y = var_184_promoted)[name = tensor("op_2421")]; tensor var_2422 = const()[name = tensor("op_2422"), val = tensor(0x1.388p+13)]; tensor var_2423 = mul(x = var_2421, y = var_2422)[name = tensor("op_2423")]; tensor pc_ev = add(x = var_2418, y = var_2423)[name = tensor("pc_ev")]; tensor var_2426_keep_dims_0 = const()[name = tensor("op_2426_keep_dims_0"), val = tensor(false)]; tensor var_2426 = reduce_min(keep_dims = var_2426_keep_dims_0, x = pc_ev)[name = tensor("op_2426")]; tensor pc_ev_min_1 = abs(x = var_2426)[name = tensor("pc_ev_min_1")]; tensor var_2446_promoted = const()[name = tensor("op_2446_promoted"), val = tensor(0x0p+0)]; tensor var_2447 = greater(x = is_resolved, y = var_2446_promoted)[name = tensor("op_2447")]; tensor var_2447_promoted_dtype_0 = const()[name = tensor("op_2447_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2452_keep_dims_0 = const()[name = tensor("op_2452_keep_dims_0"), val = tensor(false)]; tensor var_2447_promoted = cast(dtype = var_2447_promoted_dtype_0, x = var_2447)[name = tensor("cast_146")]; tensor var_2452 = reduce_sum(keep_dims = var_2452_keep_dims_0, x = var_2447_promoted)[name = tensor("op_2452")]; tensor var_2453 = const()[name = tensor("op_2453"), val = tensor(-0x1.388p+13)]; tensor var_2454 = greater(x = is_resolved, y = var_2453)[name = tensor("op_2454")]; tensor var_2454_promoted_dtype_0 = const()[name = tensor("op_2454_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2459_keep_dims_0 = const()[name = tensor("op_2459_keep_dims_0"), val = tensor(false)]; tensor var_2454_promoted = cast(dtype = var_2454_promoted_dtype_0, x = var_2454)[name = tensor("cast_145")]; tensor var_2459 = reduce_sum(keep_dims = var_2459_keep_dims_0, x = var_2454_promoted)[name = tensor("op_2459")]; tensor var_2460 = equal(x = var_2452, y = var_2459)[name = tensor("op_2460")]; tensor var_2460_promoted_dtype_0 = const()[name = tensor("op_2460_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2467 = mul(x = pc_ev_min_1, y = var_2330)[name = tensor("op_2467")]; tensor var_2468_promoted = const()[name = tensor("op_2468_promoted"), val = tensor(0x1p+0)]; tensor var_2460_promoted = cast(dtype = var_2460_promoted_dtype_0, x = var_2460)[name = tensor("cast_144")]; tensor var_2470 = sub(x = var_2468_promoted, y = var_2460_promoted)[name = tensor("op_2470")]; tensor pc_ev_min = mul(x = var_2467, y = var_2470)[name = tensor("pc_ev_min")]; tensor var_2473_keep_dims_0 = const()[name = tensor("op_2473_keep_dims_0"), val = tensor(false)]; tensor var_2473 = reduce_sum(keep_dims = var_2473_keep_dims_0, x = var_174_promoted)[name = tensor("op_2473")]; tensor var_2474 = const()[name = tensor("op_2474"), val = tensor(0x1.8p+0)]; tensor var_2475 = less(x = var_2473, y = var_2474)[name = tensor("op_2475")]; tensor var_2475_promoted_dtype_0 = const()[name = tensor("op_2475_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2475_promoted = cast(dtype = var_2475_promoted_dtype_0, x = var_2475)[name = tensor("cast_143")]; tensor var_2476 = mul(x = forced_parameter_confirm, y = var_2475_promoted)[name = tensor("op_2476")]; tensor var_2479 = equal(x = pc_ev, y = pc_ev_min)[name = tensor("op_2479")]; tensor var_2480 = const()[name = tensor("op_2480"), val = tensor(0)]; tensor var_2479_promoted_dtype_0 = const()[name = tensor("op_2479_promoted_dtype_0"), val = tensor("int32")]; tensor var_2479_promoted = cast(dtype = var_2479_promoted_dtype_0, x = var_2479)[name = tensor("cast_142")]; tensor var_2481 = greater(x = var_2479_promoted, y = var_2480)[name = tensor("op_2481")]; tensor var_2481_promoted_dtype_0 = const()[name = tensor("op_2481_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2481_promoted = cast(dtype = var_2481_promoted_dtype_0, x = var_2481)[name = tensor("cast_141")]; tensor var_2482 = mul(x = var_2380, y = var_2481_promoted)[name = tensor("op_2482")]; tensor var_2486_promoted = const()[name = tensor("op_2486_promoted"), val = tensor(0x0p+0)]; tensor var_2487 = greater(x = var_2482, y = var_2486_promoted)[name = tensor("op_2487")]; tensor var_2487_promoted_dtype_0 = const()[name = tensor("op_2487_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2487_promoted = cast(dtype = var_2487_promoted_dtype_0, x = var_2487)[name = tensor("cast_140")]; tensor pc = mul(x = twos_1, y = var_2487_promoted)[name = tensor("pc")]; tensor var_2496_axes_0 = const()[name = tensor("op_2496_axes_0"), val = tensor([1])]; tensor var_2496_keep_dims_0 = const()[name = tensor("op_2496_keep_dims_0"), val = tensor(false)]; tensor var_2496 = reduce_sum(axes = var_2496_axes_0, keep_dims = var_2496_keep_dims_0, x = pc)[name = tensor("op_2496")]; tensor reduce_max_7_axes_0 = const()[name = tensor("reduce_max_7_axes_0"), val = tensor([0])]; tensor reduce_max_7_keep_dims_0 = const()[name = tensor("reduce_max_7_keep_dims_0"), val = tensor(false)]; tensor reduce_max_7 = reduce_max(axes = reduce_max_7_axes_0, keep_dims = reduce_max_7_keep_dims_0, x = var_2496)[name = tensor("reduce_max_7")]; tensor var_2501 = const()[name = tensor("op_2501"), val = tensor(0x1p+1)]; tensor var_2502 = less_equal(x = reduce_max_7, y = var_2501)[name = tensor("op_2502")]; tensor var_2502_promoted_dtype_0 = const()[name = tensor("op_2502_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2502_promoted = cast(dtype = var_2502_promoted_dtype_0, x = var_2502)[name = tensor("cast_139")]; tensor var_2508 = mul(x = pc, y = var_2502_promoted)[name = tensor("op_2508")]; tensor var_2509_promoted = const()[name = tensor("op_2509_promoted"), val = tensor(0x1p+0)]; tensor var_2511 = sub(x = var_2509_promoted, y = var_2502_promoted)[name = tensor("op_2511")]; tensor var_2512_promoted = const()[name = tensor("op_2512_promoted"), val = tensor(-0x1p+0)]; tensor var_2513 = mul(x = pc, y = var_2512_promoted)[name = tensor("op_2513")]; tensor var_2514 = const()[name = tensor("op_2514"), val = tensor(-1)]; tensor logical_not_6 = const()[name = tensor("logical_not_6"), val = tensor(true)]; tensor var_2516 = argsort(ascending = logical_not_6, axis = var_2514, x = var_2513)[name = tensor("op_2516")]; tensor var_2517 = const()[name = tensor("op_2517"), val = tensor(0)]; tensor var_2518 = equal(x = var_2516, y = var_2517)[name = tensor("op_2518")]; tensor var_2518_promoted_dtype_0 = const()[name = tensor("op_2518_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2518_promoted = cast(dtype = var_2518_promoted_dtype_0, x = var_2518)[name = tensor("cast_138")]; tensor var_2519 = mul(x = var_2511, y = var_2518_promoted)[name = tensor("op_2519")]; tensor var_2520 = mul(x = var_2519, y = pc)[name = tensor("op_2520")]; tensor param_confirm = add(x = var_2508, y = var_2520)[name = tensor("param_confirm")]; tensor var_2523_promoted_dtype_0 = const()[name = tensor("op_2523_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2529_keep_dims_0 = const()[name = tensor("op_2529_keep_dims_0"), val = tensor(false)]; tensor var_2529_axis_0 = const()[name = tensor("op_2529_axis_0"), val = tensor(-1)]; tensor var_2523_promoted = cast(dtype = var_2523_promoted_dtype_0, x = var_2479)[name = tensor("cast_137")]; tensor var_2529 = reduce_argmax(axis = var_2529_axis_0, keep_dims = var_2529_keep_dims_0, x = var_2523_promoted)[name = tensor("op_2529")]; tensor var_2531 = const()[name = tensor("op_2531"), val = tensor(1)]; tensor which_col = add(x = var_2529, y = var_2531)[name = tensor("which_col")]; tensor param_confirm_id = const()[name = tensor("param_confirm_id"), val = tensor([3, 2])]; tensor reshape_41 = const()[name = tensor("reshape_41"), val = tensor([1])]; tensor reshape_42_shape_0 = const()[name = tensor("reshape_42_shape_0"), val = tensor([-1])]; tensor reshape_42 = reshape(shape = reshape_42_shape_0, x = which_col)[name = tensor("reshape_42")]; tensor scatter_8_mode_0 = const()[name = tensor("scatter_8_mode_0"), val = tensor("update")]; tensor scatter_8_axis_0 = const()[name = tensor("scatter_8_axis_0"), val = tensor(0)]; tensor scatter_8 = scatter(axis = scatter_8_axis_0, data = param_confirm_id, indices = reshape_41, mode = scatter_8_mode_0, updates = reshape_42)[name = tensor("scatter_8")]; tensor var_2539 = const()[name = tensor("op_2539"), val = tensor(0x1.86ap+16)]; tensor var_2540 = mul(x = var_2476, y = var_2539)[name = tensor("op_2540")]; tensor param_confirm_ev = add(x = pc_ev_min, y = var_2540)[name = tensor("param_confirm_ev")]; tensor var_2546 = abs(x = unique_candidates)[name = tensor("op_2546")]; tensor var_2547 = const()[name = tensor("op_2547"), val = tensor(0x1.0c6f7ap-20)]; tensor var_2548 = greater(x = var_2546, y = var_2547)[name = tensor("op_2548")]; tensor var_2548_promoted_dtype_0 = const()[name = tensor("op_2548_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2548_promoted = cast(dtype = var_2548_promoted_dtype_0, x = var_2548)[name = tensor("cast_136")]; tensor unique_mu = mul(x = dirichlet_mu_1, y = var_2548_promoted)[name = tensor("unique_mu")]; tensor var_2550_promoted = const()[name = tensor("op_2550_promoted"), val = tensor(-0x1p+0)]; tensor var_2551 = mul(x = unique_mu, y = var_2550_promoted)[name = tensor("op_2551")]; tensor var_2552 = const()[name = tensor("op_2552"), val = tensor(0)]; tensor logical_not_7 = const()[name = tensor("logical_not_7"), val = tensor(true)]; tensor j = argsort(ascending = logical_not_7, axis = var_2552, x = var_2551)[name = tensor("j")]; tensor var_2556 = const()[name = tensor("op_2556"), val = tensor(0)]; tensor by_mu = gather_along_axis(axis = var_2556, indices = j, x = unique_mu)[name = tensor("by_mu")]; tensor var_2559 = const()[name = tensor("op_2559"), val = tensor(0)]; tensor cdf_exclusive_0 = const()[name = tensor("cdf_exclusive_0"), val = tensor(false)]; tensor cdf_reverse_0 = const()[name = tensor("cdf_reverse_0"), val = tensor(false)]; tensor cdf = cumsum(axis = var_2559, exclusive = cdf_exclusive_0, reverse = cdf_reverse_0, x = by_mu)[name = tensor("cdf")]; tensor var_2562_promoted = const()[name = tensor("op_2562_promoted"), val = tensor(0x0p+0)]; tensor var_2563 = equal(x = unique_mu, y = var_2562_promoted)[name = tensor("op_2563")]; tensor var_2563_promoted_dtype_0 = const()[name = tensor("op_2563_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2571_axes_0 = const()[name = tensor("op_2571_axes_0"), val = tensor([0])]; tensor var_2571_keep_dims_0 = const()[name = tensor("op_2571_keep_dims_0"), val = tensor(false)]; tensor var_2563_promoted = cast(dtype = var_2563_promoted_dtype_0, x = var_2563)[name = tensor("cast_135")]; tensor var_2571 = reduce_sum(axes = var_2571_axes_0, keep_dims = var_2571_keep_dims_0, x = var_2563_promoted)[name = tensor("op_2571")]; tensor var_2573_promoted = const()[name = tensor("op_2573_promoted"), val = tensor(0x1p+0)]; tensor num_dups = add(x = var_2571, y = var_2573_promoted)[name = tensor("num_dups")]; tensor var_2575 = const()[name = tensor("op_2575"), val = tensor(-0x1.86ap+16)]; tensor var_2576 = greater(x = unique_mu, y = var_2575)[name = tensor("op_2576")]; tensor cast_64_dtype_0 = const()[name = tensor("cast_64_dtype_0"), val = tensor("fp32")]; tensor num_tuples_axes_0 = const()[name = tensor("num_tuples_axes_0"), val = tensor([0])]; tensor num_tuples_keep_dims_0 = const()[name = tensor("num_tuples_keep_dims_0"), val = tensor(false)]; tensor cast_64 = cast(dtype = cast_64_dtype_0, x = var_2576)[name = tensor("cast_134")]; tensor num_tuples = reduce_sum(axes = num_tuples_axes_0, keep_dims = num_tuples_keep_dims_0, x = cast_64)[name = tensor("num_tuples")]; tensor var_2582 = equal(x = num_dups, y = num_tuples)[name = tensor("op_2582")]; tensor var_2582_promoted_dtype_0 = const()[name = tensor("op_2582_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2586_promoted = const()[name = tensor("op_2586_promoted"), val = tensor(0x1p+0)]; tensor var_2582_promoted = cast(dtype = var_2582_promoted_dtype_0, x = var_2582)[name = tensor("cast_133")]; tensor one_candidate_disambig_1 = sub(x = var_2586_promoted, y = var_2582_promoted)[name = tensor("one_candidate_disambig_1")]; tensor one_candidate_disambig_axes_0 = const()[name = tensor("one_candidate_disambig_axes_0"), val = tensor([0])]; tensor one_candidate_disambig = expand_dims(axes = one_candidate_disambig_axes_0, x = one_candidate_disambig_1)[name = tensor("one_candidate_disambig")]; tensor var_2591 = const()[name = tensor("op_2591"), val = tensor(-0x1.388p+13)]; tensor var_2592 = greater(x = unique_mu, y = var_2591)[name = tensor("op_2592")]; tensor var_2592_promoted_dtype_0 = const()[name = tensor("op_2592_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2596 = const()[name = tensor("op_2596"), val = tensor(0)]; tensor indices_1_exclusive_0 = const()[name = tensor("indices_1_exclusive_0"), val = tensor(false)]; tensor indices_1_reverse_0 = const()[name = tensor("indices_1_reverse_0"), val = tensor(false)]; tensor var_2592_promoted = cast(dtype = var_2592_promoted_dtype_0, x = var_2592)[name = tensor("cast_132")]; tensor indices_1 = cumsum(axis = var_2596, exclusive = indices_1_exclusive_0, reverse = indices_1_reverse_0, x = var_2592_promoted)[name = tensor("indices_1")]; tensor var_2599 = const()[name = tensor("op_2599"), val = tensor(0x1p+1)]; tensor var_2600 = equal(x = indices_1, y = var_2599)[name = tensor("op_2600")]; tensor var_2600_promoted_dtype_0 = const()[name = tensor("op_2600_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2604_promoted = const()[name = tensor("op_2604_promoted"), val = tensor(0x1p+0)]; tensor var_2605 = greater(x = indices_1, y = var_2604_promoted)[name = tensor("op_2605")]; tensor var_2605_promoted_dtype_0 = const()[name = tensor("op_2605_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2609_promoted = const()[name = tensor("op_2609_promoted"), val = tensor(0x1p+2)]; tensor var_2610 = less(x = indices_1, y = var_2609_promoted)[name = tensor("op_2610")]; tensor var_2610_promoted_dtype_0 = const()[name = tensor("op_2610_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2610_promoted = cast(dtype = var_2610_promoted_dtype_0, x = var_2610)[name = tensor("cast_129")]; tensor var_2605_promoted = cast(dtype = var_2605_promoted_dtype_0, x = var_2605)[name = tensor("cast_130")]; tensor ind_2_3 = mul(x = var_2605_promoted, y = var_2610_promoted)[name = tensor("ind_2_3")]; tensor var_2615_promoted = const()[name = tensor("op_2615_promoted"), val = tensor(0x0p+0)]; tensor var_2616 = greater(x = indices_1, y = var_2615_promoted)[name = tensor("op_2616")]; tensor var_2616_promoted_dtype_0 = const()[name = tensor("op_2616_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2620_promoted = const()[name = tensor("op_2620_promoted"), val = tensor(0x1.8p+2)]; tensor var_2621 = less(x = indices_1, y = var_2620_promoted)[name = tensor("op_2621")]; tensor var_2621_promoted_dtype_0 = const()[name = tensor("op_2621_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2621_promoted = cast(dtype = var_2621_promoted_dtype_0, x = var_2621)[name = tensor("cast_127")]; tensor var_2616_promoted = cast(dtype = var_2616_promoted_dtype_0, x = var_2616)[name = tensor("cast_128")]; tensor ind_1_5 = mul(x = var_2616_promoted, y = var_2621_promoted)[name = tensor("ind_1_5")]; tensor var_2631_promoted = const()[name = tensor("op_2631_promoted"), val = tensor(0x1.5p+4)]; tensor var_2632 = less(x = indices_1, y = var_2631_promoted)[name = tensor("op_2632")]; tensor var_2632_promoted_dtype_0 = const()[name = tensor("op_2632_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2632_promoted = cast(dtype = var_2632_promoted_dtype_0, x = var_2632)[name = tensor("cast_126")]; tensor ind_2_20 = mul(x = var_2605_promoted, y = var_2632_promoted)[name = tensor("ind_2_20")]; tensor ind_2_5 = mul(x = var_2605_promoted, y = var_2621_promoted)[name = tensor("ind_2_5")]; tensor var_2648_promoted = const()[name = tensor("op_2648_promoted"), val = tensor(0x1p+1)]; tensor var_2649 = greater(x = indices_1, y = var_2648_promoted)[name = tensor("op_2649")]; tensor var_2649_promoted_dtype_0 = const()[name = tensor("op_2649_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2649_promoted = cast(dtype = var_2649_promoted_dtype_0, x = var_2649)[name = tensor("cast_125")]; tensor ind_3_5 = mul(x = var_2649_promoted, y = var_2621_promoted)[name = tensor("ind_3_5")]; tensor var_2659_promoted = const()[name = tensor("op_2659_promoted"), val = tensor(0x1.8p+1)]; tensor var_2660 = greater(x = indices_1, y = var_2659_promoted)[name = tensor("op_2660")]; tensor var_2660_promoted_dtype_0 = const()[name = tensor("op_2660_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2660_promoted = cast(dtype = var_2660_promoted_dtype_0, x = var_2660)[name = tensor("cast_124")]; tensor ind_4_5 = mul(x = var_2660_promoted, y = var_2621_promoted)[name = tensor("ind_4_5")]; tensor var_2670_promoted = const()[name = tensor("op_2670_promoted"), val = tensor(0x1.4p+2)]; tensor var_2671 = greater(x = indices_1, y = var_2670_promoted)[name = tensor("op_2671")]; tensor var_2671_promoted_dtype_0 = const()[name = tensor("op_2671_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2671_promoted = cast(dtype = var_2671_promoted_dtype_0, x = var_2671)[name = tensor("cast_123")]; tensor ind_6_20 = mul(x = var_2671_promoted, y = var_2632_promoted)[name = tensor("ind_6_20")]; tensor var_2686_promoted = const()[name = tensor("op_2686_promoted"), val = tensor(0x1.4p+4)]; tensor var_2687 = greater(x = indices_1, y = var_2686_promoted)[name = tensor("op_2687")]; tensor var_2687_promoted_dtype_0 = const()[name = tensor("op_2687_promoted_dtype_0"), val = tensor("fp32")]; tensor risk_transposed_perm_0 = const()[name = tensor("risk_transposed_perm_0"), val = tensor([1, 0])]; tensor var_2696_begin_0 = const()[name = tensor("op_2696_begin_0"), val = tensor([3, 0])]; tensor var_2696_end_0 = const()[name = tensor("op_2696_end_0"), val = tensor([4, 50])]; tensor var_2696_end_mask_0 = const()[name = tensor("op_2696_end_mask_0"), val = tensor([false, true])]; tensor var_2696_squeeze_mask_0 = const()[name = tensor("op_2696_squeeze_mask_0"), val = tensor([true, false])]; tensor risk_transposed = transpose(perm = risk_transposed_perm_0, x = risk_vector)[name = tensor("transpose_22")]; tensor var_2696 = slice_by_index(begin = var_2696_begin_0, end = var_2696_end_0, end_mask = var_2696_end_mask_0, squeeze_mask = var_2696_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2696")]; tensor var_2698_axes_0 = const()[name = tensor("op_2698_axes_0"), val = tensor([-1])]; tensor var_2698 = expand_dims(axes = var_2698_axes_0, x = var_2696)[name = tensor("op_2698")]; tensor var_2600_promoted = cast(dtype = var_2600_promoted_dtype_0, x = var_2600)[name = tensor("cast_131")]; tensor var_2699 = mul(x = var_2600_promoted, y = var_2698)[name = tensor("op_2699")]; tensor var_2700 = mul(x = var_2699, y = cdf)[name = tensor("op_2700")]; tensor var_2703_begin_0 = const()[name = tensor("op_2703_begin_0"), val = tensor([22, 0])]; tensor var_2703_end_0 = const()[name = tensor("op_2703_end_0"), val = tensor([23, 50])]; tensor var_2703_end_mask_0 = const()[name = tensor("op_2703_end_mask_0"), val = tensor([false, true])]; tensor var_2703_squeeze_mask_0 = const()[name = tensor("op_2703_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2703 = slice_by_index(begin = var_2703_begin_0, end = var_2703_end_0, end_mask = var_2703_end_mask_0, squeeze_mask = var_2703_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2703")]; tensor var_2705_axes_0 = const()[name = tensor("op_2705_axes_0"), val = tensor([-1])]; tensor var_2705 = expand_dims(axes = var_2705_axes_0, x = var_2703)[name = tensor("op_2705")]; tensor var_2706 = mul(x = var_2600_promoted, y = var_2705)[name = tensor("op_2706")]; tensor var_2707 = const()[name = tensor("op_2707"), val = tensor(0x1p+0)]; tensor var_2709 = sub(x = var_2707, y = cdf)[name = tensor("op_2709")]; tensor var_2710 = mul(x = var_2706, y = var_2709)[name = tensor("op_2710")]; tensor var_2712 = add(x = var_2700, y = var_2710)[name = tensor("op_2712")]; tensor var_2715_begin_0 = const()[name = tensor("op_2715_begin_0"), val = tensor([4, 0])]; tensor var_2715_end_0 = const()[name = tensor("op_2715_end_0"), val = tensor([5, 50])]; tensor var_2715_end_mask_0 = const()[name = tensor("op_2715_end_mask_0"), val = tensor([false, true])]; tensor var_2715_squeeze_mask_0 = const()[name = tensor("op_2715_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2715 = slice_by_index(begin = var_2715_begin_0, end = var_2715_end_0, end_mask = var_2715_end_mask_0, squeeze_mask = var_2715_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2715")]; tensor var_2717_axes_0 = const()[name = tensor("op_2717_axes_0"), val = tensor([-1])]; tensor var_2717 = expand_dims(axes = var_2717_axes_0, x = var_2715)[name = tensor("op_2717")]; tensor var_2718 = mul(x = ind_2_3, y = var_2717)[name = tensor("op_2718")]; tensor var_2719 = mul(x = var_2718, y = cdf)[name = tensor("op_2719")]; tensor var_2721 = add(x = var_2712, y = var_2719)[name = tensor("op_2721")]; tensor var_2724_begin_0 = const()[name = tensor("op_2724_begin_0"), val = tensor([21, 0])]; tensor var_2724_end_0 = const()[name = tensor("op_2724_end_0"), val = tensor([22, 50])]; tensor var_2724_end_mask_0 = const()[name = tensor("op_2724_end_mask_0"), val = tensor([false, true])]; tensor var_2724_squeeze_mask_0 = const()[name = tensor("op_2724_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2724 = slice_by_index(begin = var_2724_begin_0, end = var_2724_end_0, end_mask = var_2724_end_mask_0, squeeze_mask = var_2724_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2724")]; tensor var_2726_axes_0 = const()[name = tensor("op_2726_axes_0"), val = tensor([-1])]; tensor var_2726 = expand_dims(axes = var_2726_axes_0, x = var_2724)[name = tensor("op_2726")]; tensor var_2727 = mul(x = ind_2_3, y = var_2726)[name = tensor("op_2727")]; tensor var_2731 = mul(x = var_2727, y = var_2709)[name = tensor("op_2731")]; tensor var_2733 = add(x = var_2721, y = var_2731)[name = tensor("op_2733")]; tensor var_2736_begin_0 = const()[name = tensor("op_2736_begin_0"), val = tensor([5, 0])]; tensor var_2736_end_0 = const()[name = tensor("op_2736_end_0"), val = tensor([6, 50])]; tensor var_2736_end_mask_0 = const()[name = tensor("op_2736_end_mask_0"), val = tensor([false, true])]; tensor var_2736_squeeze_mask_0 = const()[name = tensor("op_2736_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2736 = slice_by_index(begin = var_2736_begin_0, end = var_2736_end_0, end_mask = var_2736_end_mask_0, squeeze_mask = var_2736_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2736")]; tensor var_2738_axes_0 = const()[name = tensor("op_2738_axes_0"), val = tensor([-1])]; tensor var_2738 = expand_dims(axes = var_2738_axes_0, x = var_2736)[name = tensor("op_2738")]; tensor var_2739 = mul(x = ind_1_5, y = var_2738)[name = tensor("op_2739")]; tensor var_2740 = mul(x = var_2739, y = cdf)[name = tensor("op_2740")]; tensor var_2742 = add(x = var_2733, y = var_2740)[name = tensor("op_2742")]; tensor var_2745_begin_0 = const()[name = tensor("op_2745_begin_0"), val = tensor([20, 0])]; tensor var_2745_end_0 = const()[name = tensor("op_2745_end_0"), val = tensor([21, 50])]; tensor var_2745_end_mask_0 = const()[name = tensor("op_2745_end_mask_0"), val = tensor([false, true])]; tensor var_2745_squeeze_mask_0 = const()[name = tensor("op_2745_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2745 = slice_by_index(begin = var_2745_begin_0, end = var_2745_end_0, end_mask = var_2745_end_mask_0, squeeze_mask = var_2745_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2745")]; tensor var_2747_axes_0 = const()[name = tensor("op_2747_axes_0"), val = tensor([-1])]; tensor var_2747 = expand_dims(axes = var_2747_axes_0, x = var_2745)[name = tensor("op_2747")]; tensor var_2748 = mul(x = ind_1_5, y = var_2747)[name = tensor("op_2748")]; tensor var_2752 = mul(x = var_2748, y = var_2709)[name = tensor("op_2752")]; tensor var_2754 = add(x = var_2742, y = var_2752)[name = tensor("op_2754")]; tensor var_2757_begin_0 = const()[name = tensor("op_2757_begin_0"), val = tensor([6, 0])]; tensor var_2757_end_0 = const()[name = tensor("op_2757_end_0"), val = tensor([7, 50])]; tensor var_2757_end_mask_0 = const()[name = tensor("op_2757_end_mask_0"), val = tensor([false, true])]; tensor var_2757_squeeze_mask_0 = const()[name = tensor("op_2757_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2757 = slice_by_index(begin = var_2757_begin_0, end = var_2757_end_0, end_mask = var_2757_end_mask_0, squeeze_mask = var_2757_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2757")]; tensor var_2759_axes_0 = const()[name = tensor("op_2759_axes_0"), val = tensor([-1])]; tensor var_2759 = expand_dims(axes = var_2759_axes_0, x = var_2757)[name = tensor("op_2759")]; tensor var_2760 = mul(x = ind_2_5, y = var_2759)[name = tensor("op_2760")]; tensor var_2761 = mul(x = var_2760, y = cdf)[name = tensor("op_2761")]; tensor var_2763 = add(x = var_2754, y = var_2761)[name = tensor("op_2763")]; tensor var_2766_begin_0 = const()[name = tensor("op_2766_begin_0"), val = tensor([19, 0])]; tensor var_2766_end_0 = const()[name = tensor("op_2766_end_0"), val = tensor([20, 50])]; tensor var_2766_end_mask_0 = const()[name = tensor("op_2766_end_mask_0"), val = tensor([false, true])]; tensor var_2766_squeeze_mask_0 = const()[name = tensor("op_2766_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2766 = slice_by_index(begin = var_2766_begin_0, end = var_2766_end_0, end_mask = var_2766_end_mask_0, squeeze_mask = var_2766_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2766")]; tensor var_2768_axes_0 = const()[name = tensor("op_2768_axes_0"), val = tensor([-1])]; tensor var_2768 = expand_dims(axes = var_2768_axes_0, x = var_2766)[name = tensor("op_2768")]; tensor var_2769 = mul(x = ind_2_5, y = var_2768)[name = tensor("op_2769")]; tensor var_2773 = mul(x = var_2769, y = var_2709)[name = tensor("op_2773")]; tensor var_2775 = add(x = var_2763, y = var_2773)[name = tensor("op_2775")]; tensor var_2778_begin_0 = const()[name = tensor("op_2778_begin_0"), val = tensor([7, 0])]; tensor var_2778_end_0 = const()[name = tensor("op_2778_end_0"), val = tensor([8, 50])]; tensor var_2778_end_mask_0 = const()[name = tensor("op_2778_end_mask_0"), val = tensor([false, true])]; tensor var_2778_squeeze_mask_0 = const()[name = tensor("op_2778_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2778 = slice_by_index(begin = var_2778_begin_0, end = var_2778_end_0, end_mask = var_2778_end_mask_0, squeeze_mask = var_2778_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2778")]; tensor var_2780_axes_0 = const()[name = tensor("op_2780_axes_0"), val = tensor([-1])]; tensor var_2780 = expand_dims(axes = var_2780_axes_0, x = var_2778)[name = tensor("op_2780")]; tensor var_2781 = mul(x = ind_3_5, y = var_2780)[name = tensor("op_2781")]; tensor var_2782 = mul(x = var_2781, y = cdf)[name = tensor("op_2782")]; tensor var_2784 = add(x = var_2775, y = var_2782)[name = tensor("op_2784")]; tensor var_2787_begin_0 = const()[name = tensor("op_2787_begin_0"), val = tensor([18, 0])]; tensor var_2787_end_0 = const()[name = tensor("op_2787_end_0"), val = tensor([19, 50])]; tensor var_2787_end_mask_0 = const()[name = tensor("op_2787_end_mask_0"), val = tensor([false, true])]; tensor var_2787_squeeze_mask_0 = const()[name = tensor("op_2787_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2787 = slice_by_index(begin = var_2787_begin_0, end = var_2787_end_0, end_mask = var_2787_end_mask_0, squeeze_mask = var_2787_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2787")]; tensor var_2789_axes_0 = const()[name = tensor("op_2789_axes_0"), val = tensor([-1])]; tensor var_2789 = expand_dims(axes = var_2789_axes_0, x = var_2787)[name = tensor("op_2789")]; tensor var_2790 = mul(x = ind_3_5, y = var_2789)[name = tensor("op_2790")]; tensor var_2794 = mul(x = var_2790, y = var_2709)[name = tensor("op_2794")]; tensor var_2796 = add(x = var_2784, y = var_2794)[name = tensor("op_2796")]; tensor var_2799_begin_0 = const()[name = tensor("op_2799_begin_0"), val = tensor([8, 0])]; tensor var_2799_end_0 = const()[name = tensor("op_2799_end_0"), val = tensor([9, 50])]; tensor var_2799_end_mask_0 = const()[name = tensor("op_2799_end_mask_0"), val = tensor([false, true])]; tensor var_2799_squeeze_mask_0 = const()[name = tensor("op_2799_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2799 = slice_by_index(begin = var_2799_begin_0, end = var_2799_end_0, end_mask = var_2799_end_mask_0, squeeze_mask = var_2799_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2799")]; tensor var_2801_axes_0 = const()[name = tensor("op_2801_axes_0"), val = tensor([-1])]; tensor var_2801 = expand_dims(axes = var_2801_axes_0, x = var_2799)[name = tensor("op_2801")]; tensor var_2802 = mul(x = ind_4_5, y = var_2801)[name = tensor("op_2802")]; tensor var_2803 = mul(x = var_2802, y = cdf)[name = tensor("op_2803")]; tensor var_2805 = add(x = var_2796, y = var_2803)[name = tensor("op_2805")]; tensor var_2808_begin_0 = const()[name = tensor("op_2808_begin_0"), val = tensor([17, 0])]; tensor var_2808_end_0 = const()[name = tensor("op_2808_end_0"), val = tensor([18, 50])]; tensor var_2808_end_mask_0 = const()[name = tensor("op_2808_end_mask_0"), val = tensor([false, true])]; tensor var_2808_squeeze_mask_0 = const()[name = tensor("op_2808_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2808 = slice_by_index(begin = var_2808_begin_0, end = var_2808_end_0, end_mask = var_2808_end_mask_0, squeeze_mask = var_2808_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2808")]; tensor var_2810_axes_0 = const()[name = tensor("op_2810_axes_0"), val = tensor([-1])]; tensor var_2810 = expand_dims(axes = var_2810_axes_0, x = var_2808)[name = tensor("op_2810")]; tensor var_2811 = mul(x = ind_4_5, y = var_2810)[name = tensor("op_2811")]; tensor var_2815 = mul(x = var_2811, y = var_2709)[name = tensor("op_2815")]; tensor var_2817 = add(x = var_2805, y = var_2815)[name = tensor("op_2817")]; tensor var_2820_begin_0 = const()[name = tensor("op_2820_begin_0"), val = tensor([9, 0])]; tensor var_2820_end_0 = const()[name = tensor("op_2820_end_0"), val = tensor([10, 50])]; tensor var_2820_end_mask_0 = const()[name = tensor("op_2820_end_mask_0"), val = tensor([false, true])]; tensor var_2820_squeeze_mask_0 = const()[name = tensor("op_2820_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2820 = slice_by_index(begin = var_2820_begin_0, end = var_2820_end_0, end_mask = var_2820_end_mask_0, squeeze_mask = var_2820_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2820")]; tensor var_2822_axes_0 = const()[name = tensor("op_2822_axes_0"), val = tensor([-1])]; tensor var_2822 = expand_dims(axes = var_2822_axes_0, x = var_2820)[name = tensor("op_2822")]; tensor var_2823 = mul(x = ind_6_20, y = var_2822)[name = tensor("op_2823")]; tensor var_2824 = mul(x = var_2823, y = cdf)[name = tensor("op_2824")]; tensor var_2826 = add(x = var_2817, y = var_2824)[name = tensor("op_2826")]; tensor var_2829_begin_0 = const()[name = tensor("op_2829_begin_0"), val = tensor([16, 0])]; tensor var_2829_end_0 = const()[name = tensor("op_2829_end_0"), val = tensor([17, 50])]; tensor var_2829_end_mask_0 = const()[name = tensor("op_2829_end_mask_0"), val = tensor([false, true])]; tensor var_2829_squeeze_mask_0 = const()[name = tensor("op_2829_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2829 = slice_by_index(begin = var_2829_begin_0, end = var_2829_end_0, end_mask = var_2829_end_mask_0, squeeze_mask = var_2829_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2829")]; tensor var_2831_axes_0 = const()[name = tensor("op_2831_axes_0"), val = tensor([-1])]; tensor var_2831 = expand_dims(axes = var_2831_axes_0, x = var_2829)[name = tensor("op_2831")]; tensor var_2832 = mul(x = ind_6_20, y = var_2831)[name = tensor("op_2832")]; tensor var_2836 = mul(x = var_2832, y = var_2709)[name = tensor("op_2836")]; tensor var_2838 = add(x = var_2826, y = var_2836)[name = tensor("op_2838")]; tensor var_2841_begin_0 = const()[name = tensor("op_2841_begin_0"), val = tensor([10, 0])]; tensor var_2841_end_0 = const()[name = tensor("op_2841_end_0"), val = tensor([11, 50])]; tensor var_2841_end_mask_0 = const()[name = tensor("op_2841_end_mask_0"), val = tensor([false, true])]; tensor var_2841_squeeze_mask_0 = const()[name = tensor("op_2841_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2841 = slice_by_index(begin = var_2841_begin_0, end = var_2841_end_0, end_mask = var_2841_end_mask_0, squeeze_mask = var_2841_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2841")]; tensor var_2843_axes_0 = const()[name = tensor("op_2843_axes_0"), val = tensor([-1])]; tensor var_2843 = expand_dims(axes = var_2843_axes_0, x = var_2841)[name = tensor("op_2843")]; tensor var_2844 = mul(x = var_2671_promoted, y = var_2843)[name = tensor("op_2844")]; tensor var_2845 = mul(x = var_2844, y = cdf)[name = tensor("op_2845")]; tensor var_2847 = add(x = var_2838, y = var_2845)[name = tensor("op_2847")]; tensor var_2850_begin_0 = const()[name = tensor("op_2850_begin_0"), val = tensor([15, 0])]; tensor var_2850_end_0 = const()[name = tensor("op_2850_end_0"), val = tensor([16, 50])]; tensor var_2850_end_mask_0 = const()[name = tensor("op_2850_end_mask_0"), val = tensor([false, true])]; tensor var_2850_squeeze_mask_0 = const()[name = tensor("op_2850_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2850 = slice_by_index(begin = var_2850_begin_0, end = var_2850_end_0, end_mask = var_2850_end_mask_0, squeeze_mask = var_2850_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2850")]; tensor var_2852_axes_0 = const()[name = tensor("op_2852_axes_0"), val = tensor([-1])]; tensor var_2852 = expand_dims(axes = var_2852_axes_0, x = var_2850)[name = tensor("op_2852")]; tensor var_2853 = mul(x = var_2671_promoted, y = var_2852)[name = tensor("op_2853")]; tensor var_2857 = mul(x = var_2853, y = var_2709)[name = tensor("op_2857")]; tensor var_2859 = add(x = var_2847, y = var_2857)[name = tensor("op_2859")]; tensor var_2862_begin_0 = const()[name = tensor("op_2862_begin_0"), val = tensor([11, 0])]; tensor var_2862_end_0 = const()[name = tensor("op_2862_end_0"), val = tensor([12, 50])]; tensor var_2862_end_mask_0 = const()[name = tensor("op_2862_end_mask_0"), val = tensor([false, true])]; tensor var_2862_squeeze_mask_0 = const()[name = tensor("op_2862_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2862 = slice_by_index(begin = var_2862_begin_0, end = var_2862_end_0, end_mask = var_2862_end_mask_0, squeeze_mask = var_2862_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2862")]; tensor var_2864_axes_0 = const()[name = tensor("op_2864_axes_0"), val = tensor([-1])]; tensor var_2864 = expand_dims(axes = var_2864_axes_0, x = var_2862)[name = tensor("op_2864")]; tensor var_2865 = mul(x = ind_2_20, y = var_2864)[name = tensor("op_2865")]; tensor var_2866 = mul(x = var_2865, y = cdf)[name = tensor("op_2866")]; tensor var_2868 = add(x = var_2859, y = var_2866)[name = tensor("op_2868")]; tensor var_2871_begin_0 = const()[name = tensor("op_2871_begin_0"), val = tensor([14, 0])]; tensor var_2871_end_0 = const()[name = tensor("op_2871_end_0"), val = tensor([15, 50])]; tensor var_2871_end_mask_0 = const()[name = tensor("op_2871_end_mask_0"), val = tensor([false, true])]; tensor var_2871_squeeze_mask_0 = const()[name = tensor("op_2871_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2871 = slice_by_index(begin = var_2871_begin_0, end = var_2871_end_0, end_mask = var_2871_end_mask_0, squeeze_mask = var_2871_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2871")]; tensor var_2873_axes_0 = const()[name = tensor("op_2873_axes_0"), val = tensor([-1])]; tensor var_2873 = expand_dims(axes = var_2873_axes_0, x = var_2871)[name = tensor("op_2873")]; tensor var_2874 = mul(x = ind_2_20, y = var_2873)[name = tensor("op_2874")]; tensor var_2878 = mul(x = var_2874, y = var_2709)[name = tensor("op_2878")]; tensor var_2880 = add(x = var_2868, y = var_2878)[name = tensor("op_2880")]; tensor var_2883_begin_0 = const()[name = tensor("op_2883_begin_0"), val = tensor([12, 0])]; tensor var_2883_end_0 = const()[name = tensor("op_2883_end_0"), val = tensor([13, 50])]; tensor var_2883_end_mask_0 = const()[name = tensor("op_2883_end_mask_0"), val = tensor([false, true])]; tensor var_2883_squeeze_mask_0 = const()[name = tensor("op_2883_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2883 = slice_by_index(begin = var_2883_begin_0, end = var_2883_end_0, end_mask = var_2883_end_mask_0, squeeze_mask = var_2883_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2883")]; tensor var_2885_axes_0 = const()[name = tensor("op_2885_axes_0"), val = tensor([-1])]; tensor var_2885 = expand_dims(axes = var_2885_axes_0, x = var_2883)[name = tensor("op_2885")]; tensor var_2687_promoted = cast(dtype = var_2687_promoted_dtype_0, x = var_2687)[name = tensor("cast_122")]; tensor var_2886 = mul(x = var_2687_promoted, y = var_2885)[name = tensor("op_2886")]; tensor var_2887 = mul(x = var_2886, y = cdf)[name = tensor("op_2887")]; tensor var_2889 = add(x = var_2880, y = var_2887)[name = tensor("op_2889")]; tensor var_2892_begin_0 = const()[name = tensor("op_2892_begin_0"), val = tensor([13, 0])]; tensor var_2892_end_0 = const()[name = tensor("op_2892_end_0"), val = tensor([14, 50])]; tensor var_2892_end_mask_0 = const()[name = tensor("op_2892_end_mask_0"), val = tensor([false, true])]; tensor var_2892_squeeze_mask_0 = const()[name = tensor("op_2892_squeeze_mask_0"), val = tensor([true, false])]; tensor var_2892 = slice_by_index(begin = var_2892_begin_0, end = var_2892_end_0, end_mask = var_2892_end_mask_0, squeeze_mask = var_2892_squeeze_mask_0, x = risk_transposed)[name = tensor("op_2892")]; tensor var_2894_axes_0 = const()[name = tensor("op_2894_axes_0"), val = tensor([-1])]; tensor var_2894 = expand_dims(axes = var_2894_axes_0, x = var_2892)[name = tensor("op_2894")]; tensor var_2895 = mul(x = var_2687_promoted, y = var_2894)[name = tensor("op_2895")]; tensor var_2899 = mul(x = var_2895, y = var_2709)[name = tensor("op_2899")]; tensor disambig_ev_cell_1 = add(x = var_2889, y = var_2899)[name = tensor("disambig_ev_cell_1")]; tensor var_2902 = mul(x = disambig_ev_cell_1, y = one_candidate_disambig)[name = tensor("op_2902")]; tensor disambig_ev_cell_3 = mul(x = var_2902, y = var_184_promoted)[name = tensor("disambig_ev_cell_3")]; tensor var_2904_promoted = const()[name = tensor("op_2904_promoted"), val = tensor(0x0p+0)]; tensor var_2905 = mul(x = disambig_ev_cell_3, y = var_2904_promoted)[name = tensor("op_2905")]; tensor var_2910_axes_0 = const()[name = tensor("op_2910_axes_0"), val = tensor([1])]; tensor var_2910_keep_dims_0 = const()[name = tensor("op_2910_keep_dims_0"), val = tensor(false)]; tensor var_2910 = reduce_sum(axes = var_2910_axes_0, keep_dims = var_2910_keep_dims_0, x = var_2905)[name = tensor("op_2910")]; tensor var_2912_promoted = const()[name = tensor("op_2912_promoted"), val = tensor(0x1p+0)]; tensor ditch_first = add(x = var_2910, y = var_2912_promoted)[name = tensor("ditch_first")]; tensor reshape_46 = const()[name = tensor("reshape_46"), val = tensor([0])]; tensor reshape_47 = const()[name = tensor("reshape_47"), val = tensor([0x0p+0])]; tensor scatter_9_mode_0 = const()[name = tensor("scatter_9_mode_0"), val = tensor("update")]; tensor scatter_9_axis_0 = const()[name = tensor("scatter_9_axis_0"), val = tensor(0)]; tensor scatter_9 = scatter(axis = scatter_9_axis_0, data = ditch_first, indices = reshape_46, mode = scatter_9_mode_0, updates = reshape_47)[name = tensor("scatter_9")]; tensor reshape_61_shape_0 = const()[name = tensor("reshape_61_shape_0"), val = tensor([50, 1])]; tensor reshape_61 = reshape(shape = reshape_61_shape_0, x = scatter_9)[name = tensor("reshape_61")]; tensor var_2923 = mul(x = disambig_ev_cell_3, y = reshape_61)[name = tensor("op_2923")]; tensor reduce_max_8_axes_0 = const()[name = tensor("reduce_max_8_axes_0"), val = tensor([0])]; tensor reduce_max_8_keep_dims_0 = const()[name = tensor("reduce_max_8_keep_dims_0"), val = tensor(false)]; tensor reduce_max_8 = reduce_max(axes = reduce_max_8_axes_0, keep_dims = reduce_max_8_keep_dims_0, x = var_2923)[name = tensor("reduce_max_8")]; tensor var_2932 = sub(x = var_2923, y = reduce_max_8)[name = tensor("op_2932")]; tensor max_diff = abs(x = var_2932)[name = tensor("max_diff")]; tensor reduce_min_2_axes_0 = const()[name = tensor("reduce_min_2_axes_0"), val = tensor([0])]; tensor reduce_min_2_keep_dims_0 = const()[name = tensor("reduce_min_2_keep_dims_0"), val = tensor(false)]; tensor reduce_min_2 = reduce_min(axes = reduce_min_2_axes_0, keep_dims = reduce_min_2_keep_dims_0, x = max_diff)[name = tensor("reduce_min_2")]; tensor var_2942 = const()[name = tensor("op_2942"), val = tensor(0x1.47ae14p-7)]; tensor var_2943 = add(x = reduce_min_2, y = var_2942)[name = tensor("op_2943")]; tensor var_2944 = less(x = max_diff, y = var_2943)[name = tensor("op_2944")]; tensor var_2944_promoted_dtype_0 = const()[name = tensor("op_2944_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2944_promoted = cast(dtype = var_2944_promoted_dtype_0, x = var_2944)[name = tensor("cast_121")]; tensor var_2948 = mul(x = var_2944_promoted, y = cdf)[name = tensor("op_2948")]; tensor cdf_threshold_axes_0 = const()[name = tensor("cdf_threshold_axes_0"), val = tensor([0])]; tensor cdf_threshold_keep_dims_0 = const()[name = tensor("cdf_threshold_keep_dims_0"), val = tensor(false)]; tensor cdf_threshold = reduce_sum(axes = cdf_threshold_axes_0, keep_dims = cdf_threshold_keep_dims_0, x = var_2948)[name = tensor("cdf_threshold")]; tensor var_2955 = const()[name = tensor("op_2955"), val = tensor(0x1.47ae14p-7)]; tensor var_2956 = add(x = cdf_threshold, y = var_2955)[name = tensor("op_2956")]; tensor var_2957 = less_equal(x = cdf, y = var_2956)[name = tensor("op_2957")]; tensor var_2957_promoted_dtype_0 = const()[name = tensor("op_2957_promoted_dtype_0"), val = tensor("int32")]; tensor var_2960 = const()[name = tensor("op_2960"), val = tensor(0)]; tensor logical_not_8 = const()[name = tensor("logical_not_8"), val = tensor(true)]; tensor var_2962 = argsort(ascending = logical_not_8, axis = var_2960, x = j)[name = tensor("op_2962")]; tensor var_2963 = const()[name = tensor("op_2963"), val = tensor(0)]; tensor var_2957_promoted = cast(dtype = var_2957_promoted_dtype_0, x = var_2957)[name = tensor("cast_120")]; tensor var_2965 = gather_along_axis(axis = var_2963, indices = var_2962, x = var_2957_promoted)[name = tensor("op_2965")]; tensor var_2965_promoted_dtype_0 = const()[name = tensor("op_2965_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2973_axes_0 = const()[name = tensor("op_2973_axes_0"), val = tensor([0])]; tensor var_2973_keep_dims_0 = const()[name = tensor("op_2973_keep_dims_0"), val = tensor(false)]; tensor var_2965_promoted = cast(dtype = var_2965_promoted_dtype_0, x = var_2965)[name = tensor("cast_119")]; tensor var_2973 = reduce_sum(axes = var_2973_axes_0, keep_dims = var_2973_keep_dims_0, x = var_2965_promoted)[name = tensor("op_2973")]; tensor var_2974_promoted = const()[name = tensor("op_2974_promoted"), val = tensor(0x1p+0)]; tensor var_2975 = greater(x = var_2973, y = var_2974_promoted)[name = tensor("op_2975")]; tensor var_2975_promoted_dtype_0 = const()[name = tensor("op_2975_promoted_dtype_0"), val = tensor("fp32")]; tensor var_2975_promoted = cast(dtype = var_2975_promoted_dtype_0, x = var_2975)[name = tensor("cast_118")]; tensor dis_3 = mul(x = var_2965_promoted, y = var_2975_promoted)[name = tensor("dis_3")]; tensor var_2984_promoted = const()[name = tensor("op_2984_promoted"), val = tensor(0x1p+0)]; tensor var_2986 = sub(x = var_2984_promoted, y = already_prompted)[name = tensor("op_2986")]; tensor var_2987_promoted = const()[name = tensor("op_2987_promoted"), val = tensor(0x1p+0)]; tensor var_2989 = sub(x = var_2987_promoted, y = is_resolved)[name = tensor("op_2989")]; tensor var_2990 = mul(x = var_2986, y = var_2989)[name = tensor("op_2990")]; tensor shadow_per_column = mul(x = var_2990, y = reduce_max_8)[name = tensor("shadow_per_column")]; tensor var_2995 = const()[name = tensor("op_2995"), val = tensor(0x1.388p+13)]; tensor var_2997 = sub(x = var_2995, y = reduce_max_8)[name = tensor("op_2997")]; tensor i_7 = add(x = reduce_max_8, y = random_seed)[name = tensor("i_7")]; tensor var_3000 = const()[name = tensor("op_3000"), val = tensor(0x1.eb851ep-6)]; tensor var_3001_div = floor_div(x = i_7, y = var_3000)[name = tensor("op_3001_div")]; tensor var_3001_div_scaled = mul(x = var_3001_div, y = var_3000)[name = tensor("op_3001_div_scaled")]; tensor var_3001 = sub(x = i_7, y = var_3001_div_scaled)[name = tensor("op_3001")]; tensor var_3002_promoted = const()[name = tensor("op_3002_promoted"), val = tensor(0x1.9p+6)]; tensor out_17 = mul(x = var_3001, y = var_3002_promoted)[name = tensor("out_17")]; tensor var_3004 = const()[name = tensor("op_3004"), val = tensor(0x1.1eb852p-4)]; tensor var_3005_div = floor_div(x = out_17, y = var_3004)[name = tensor("op_3005_div")]; tensor var_3005_div_scaled = mul(x = var_3005_div, y = var_3004)[name = tensor("op_3005_div_scaled")]; tensor var_3005 = sub(x = out_17, y = var_3005_div_scaled)[name = tensor("op_3005")]; tensor var_3006_promoted = const()[name = tensor("op_3006_promoted"), val = tensor(0x1.9p+6)]; tensor out_19 = mul(x = var_3005, y = var_3006_promoted)[name = tensor("out_19")]; tensor var_3008 = const()[name = tensor("op_3008"), val = tensor(0x1.851eb8p-3)]; tensor var_3009_div = floor_div(x = out_19, y = var_3008)[name = tensor("op_3009_div")]; tensor var_3009_div_scaled = mul(x = var_3009_div, y = var_3008)[name = tensor("op_3009_div_scaled")]; tensor var_3009 = sub(x = out_19, y = var_3009_div_scaled)[name = tensor("op_3009")]; tensor var_3010_promoted = const()[name = tensor("op_3010_promoted"), val = tensor(0x1.9p+6)]; tensor out = mul(x = var_3009, y = var_3010_promoted)[name = tensor("out")]; tensor var_3012 = const()[name = tensor("op_3012"), val = tensor(0x1.b645a2p-4)]; tensor var_3013_div = floor_div(x = out, y = var_3012)[name = tensor("op_3013_div")]; tensor var_3013_div_scaled = mul(x = var_3013_div, y = var_3012)[name = tensor("op_3013_div_scaled")]; tensor var_3013 = sub(x = out, y = var_3013_div_scaled)[name = tensor("op_3013")]; tensor var_3014_promoted = const()[name = tensor("op_3014_promoted"), val = tensor(0x1.f4p+9)]; tensor var_3015 = mul(x = var_3013, y = var_3014_promoted)[name = tensor("op_3015")]; tensor _inversed_3017_y_0 = const()[name = tensor("_inversed_3017_y_0"), val = tensor(0x1.323e34p-7)]; tensor _inversed_3017 = mul(x = var_3015, y = _inversed_3017_y_0)[name = tensor("_inversed_3017")]; tensor var_3018_promoted = const()[name = tensor("op_3018_promoted"), val = tensor(0x1.4p+3)]; tensor var_3019 = mul(x = _inversed_3017, y = var_3018_promoted)[name = tensor("op_3019")]; tensor to_force = add(x = var_2997, y = var_3019)[name = tensor("to_force")]; tensor var_3022 = mul(x = var_121_promoted, y = to_force)[name = tensor("op_3022")]; tensor var_3023_promoted = const()[name = tensor("op_3023_promoted"), val = tensor(0x0p+0)]; tensor var_3024 = greater(x = reduce_max_8, y = var_3023_promoted)[name = tensor("op_3024")]; tensor var_3024_promoted_dtype_0 = const()[name = tensor("op_3024_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3024_promoted = cast(dtype = var_3024_promoted_dtype_0, x = var_3024)[name = tensor("cast_117")]; tensor var_3025 = mul(x = var_3022, y = var_3024_promoted)[name = tensor("op_3025")]; tensor dis_ev_per_column_3 = add(x = reduce_max_8, y = var_3025)[name = tensor("dis_ev_per_column_3")]; tensor dis_ev_per_column = mul(x = dis_ev_per_column_3, y = var_2990)[name = tensor("dis_ev_per_column")]; tensor dis_ev_keep_dims_0 = const()[name = tensor("dis_ev_keep_dims_0"), val = tensor(false)]; tensor dis_ev = reduce_max(keep_dims = dis_ev_keep_dims_0, x = dis_ev_per_column)[name = tensor("dis_ev")]; tensor shadow_ev_keep_dims_0 = const()[name = tensor("shadow_ev_keep_dims_0"), val = tensor(false)]; tensor shadow_ev = reduce_max(keep_dims = shadow_ev_keep_dims_0, x = shadow_per_column)[name = tensor("shadow_ev")]; tensor var_3039 = sub(x = dis_ev_per_column, y = dis_ev)[name = tensor("op_3039")]; tensor var_3040 = abs(x = var_3039)[name = tensor("op_3040")]; tensor var_3041_promoted = const()[name = tensor("op_3041_promoted"), val = tensor(-0x1p+0)]; tensor var_3042 = mul(x = var_3040, y = var_3041_promoted)[name = tensor("op_3042")]; tensor var_3048_keep_dims_0 = const()[name = tensor("op_3048_keep_dims_0"), val = tensor(false)]; tensor var_3048_axis_0 = const()[name = tensor("op_3048_axis_0"), val = tensor(-1)]; tensor var_3048 = reduce_argmax(axis = var_3048_axis_0, keep_dims = var_3048_keep_dims_0, x = var_3042)[name = tensor("op_3048")]; tensor var_3050 = const()[name = tensor("op_3050"), val = tensor(1)]; tensor dis_index = add(x = var_3048, y = var_3050)[name = tensor("dis_index")]; tensor var_3053 = sub(x = shadow_per_column, y = shadow_ev)[name = tensor("op_3053")]; tensor var_3054 = abs(x = var_3053)[name = tensor("op_3054")]; tensor var_3055_promoted = const()[name = tensor("op_3055_promoted"), val = tensor(-0x1p+0)]; tensor var_3056 = mul(x = var_3054, y = var_3055_promoted)[name = tensor("op_3056")]; tensor var_3062_keep_dims_0 = const()[name = tensor("op_3062_keep_dims_0"), val = tensor(false)]; tensor var_3062_axis_0 = const()[name = tensor("op_3062_axis_0"), val = tensor(-1)]; tensor var_3062 = reduce_argmax(axis = var_3062_axis_0, keep_dims = var_3062_keep_dims_0, x = var_3056)[name = tensor("op_3062")]; tensor var_3064 = const()[name = tensor("op_3064"), val = tensor(1)]; tensor shadow_index = add(x = var_3062, y = var_3064)[name = tensor("shadow_index")]; tensor var_3066 = const()[name = tensor("op_3066"), val = tensor(-0x1.e848p+19)]; tensor var_3067 = greater(x = dis_3, y = var_3066)[name = tensor("op_3067")]; tensor var_3067_promoted_dtype_0 = const()[name = tensor("op_3067_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3071 = const()[name = tensor("op_3071"), val = tensor(1)]; tensor var_3073_exclusive_0 = const()[name = tensor("op_3073_exclusive_0"), val = tensor(false)]; tensor var_3073_reverse_0 = const()[name = tensor("op_3073_reverse_0"), val = tensor(false)]; tensor var_3067_promoted = cast(dtype = var_3067_promoted_dtype_0, x = var_3067)[name = tensor("cast_116")]; tensor var_3073 = cumsum(axis = var_3071, exclusive = var_3073_exclusive_0, reverse = var_3073_reverse_0, x = var_3067_promoted)[name = tensor("op_3073")]; tensor dis_index_promoted_dtype_0 = const()[name = tensor("dis_index_promoted_dtype_0"), val = tensor("fp32")]; tensor dis_index_promoted = cast(dtype = dis_index_promoted_dtype_0, x = dis_index)[name = tensor("cast_115")]; tensor var_3074 = equal(x = var_3073, y = dis_index_promoted)[name = tensor("op_3074")]; tensor var_3074_promoted_dtype_0 = const()[name = tensor("op_3074_promoted_dtype_0"), val = tensor("fp32")]; tensor shadow_index_promoted_dtype_0 = const()[name = tensor("shadow_index_promoted_dtype_0"), val = tensor("fp32")]; tensor shadow_index_promoted = cast(dtype = shadow_index_promoted_dtype_0, x = shadow_index)[name = tensor("cast_113")]; tensor var_3086 = equal(x = var_3073, y = shadow_index_promoted)[name = tensor("op_3086")]; tensor var_3086_promoted_dtype_0 = const()[name = tensor("op_3086_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3074_promoted = cast(dtype = var_3074_promoted_dtype_0, x = var_3074)[name = tensor("cast_114")]; tensor dis = mul(x = dis_3, y = var_3074_promoted)[name = tensor("dis")]; tensor var_3086_promoted = cast(dtype = var_3086_promoted_dtype_0, x = var_3086)[name = tensor("cast_112")]; tensor shadow_dis = mul(x = dis, y = var_3086_promoted)[name = tensor("shadow_dis")]; tensor dis_id = const()[name = tensor("dis_id"), val = tensor([4, 2])]; tensor dis_id_promoted = const()[name = tensor("dis_id_promoted"), val = tensor([0x1p+2, 0x1p+1])]; tensor reshape_51 = const()[name = tensor("reshape_51"), val = tensor([1])]; tensor reshape_52_shape_0 = const()[name = tensor("reshape_52_shape_0"), val = tensor([-1])]; tensor reshape_52 = reshape(shape = reshape_52_shape_0, x = dis_index)[name = tensor("reshape_52")]; tensor scatter_10_mode_0 = const()[name = tensor("scatter_10_mode_0"), val = tensor("update")]; tensor scatter_10_axis_0 = const()[name = tensor("scatter_10_axis_0"), val = tensor(0)]; tensor scatter_10 = scatter(axis = scatter_10_axis_0, data = dis_id, indices = reshape_51, mode = scatter_10_mode_0, updates = reshape_52)[name = tensor("scatter_10")]; tensor reshape_56 = const()[name = tensor("reshape_56"), val = tensor([1])]; tensor reshape_57_shape_0 = const()[name = tensor("reshape_57_shape_0"), val = tensor([-1])]; tensor reshape_57 = reshape(shape = reshape_57_shape_0, x = shadow_index_promoted)[name = tensor("reshape_57")]; tensor scatter_11_mode_0 = const()[name = tensor("scatter_11_mode_0"), val = tensor("update")]; tensor scatter_11_axis_0 = const()[name = tensor("scatter_11_axis_0"), val = tensor(0)]; tensor scatter_11 = scatter(axis = scatter_11_axis_0, data = dis_id_promoted, indices = reshape_56, mode = scatter_11_mode_0, updates = reshape_57)[name = tensor("scatter_11")]; tensor var_3106 = greater(x = dis_ev, y = action_confirm_ev)[name = tensor("op_3106")]; tensor var_3106_promoted_dtype_0 = const()[name = tensor("op_3106_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3110 = greater(x = dis_ev, y = param_confirm_ev)[name = tensor("op_3110")]; tensor var_3110_promoted_dtype_0 = const()[name = tensor("op_3110_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3110_promoted = cast(dtype = var_3110_promoted_dtype_0, x = var_3110)[name = tensor("cast_110")]; tensor var_3106_promoted = cast(dtype = var_3106_promoted_dtype_0, x = var_3106)[name = tensor("cast_111")]; tensor var_3114 = mul(x = var_3106_promoted, y = var_3110_promoted)[name = tensor("op_3114")]; tensor var_3115 = greater(x = dis_ev, y = de_ev)[name = tensor("op_3115")]; tensor var_3115_promoted_dtype_0 = const()[name = tensor("op_3115_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3115_promoted = cast(dtype = var_3115_promoted_dtype_0, x = var_3115)[name = tensor("cast_109")]; tensor should_disambig = mul(x = var_3114, y = var_3115_promoted)[name = tensor("should_disambig")]; tensor var_3120 = greater_equal(x = action_confirm_ev, y = dis_ev)[name = tensor("op_3120")]; tensor var_3120_promoted_dtype_0 = const()[name = tensor("op_3120_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3124 = greater(x = action_confirm_ev, y = param_confirm_ev)[name = tensor("op_3124")]; tensor var_3124_promoted_dtype_0 = const()[name = tensor("op_3124_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3124_promoted = cast(dtype = var_3124_promoted_dtype_0, x = var_3124)[name = tensor("cast_107")]; tensor var_3120_promoted = cast(dtype = var_3120_promoted_dtype_0, x = var_3120)[name = tensor("cast_108")]; tensor var_3128 = mul(x = var_3120_promoted, y = var_3124_promoted)[name = tensor("op_3128")]; tensor var_3129 = greater(x = action_confirm_ev, y = de_ev)[name = tensor("op_3129")]; tensor var_3129_promoted_dtype_0 = const()[name = tensor("op_3129_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3129_promoted = cast(dtype = var_3129_promoted_dtype_0, x = var_3129)[name = tensor("cast_106")]; tensor should_action_confirm = mul(x = var_3128, y = var_3129_promoted)[name = tensor("should_action_confirm")]; tensor var_3134 = greater_equal(x = param_confirm_ev, y = dis_ev)[name = tensor("op_3134")]; tensor var_3134_promoted_dtype_0 = const()[name = tensor("op_3134_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3138 = greater(x = param_confirm_ev, y = action_confirm_ev)[name = tensor("op_3138")]; tensor var_3138_promoted_dtype_0 = const()[name = tensor("op_3138_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3138_promoted = cast(dtype = var_3138_promoted_dtype_0, x = var_3138)[name = tensor("cast_104")]; tensor var_3134_promoted = cast(dtype = var_3134_promoted_dtype_0, x = var_3134)[name = tensor("cast_105")]; tensor var_3142 = mul(x = var_3134_promoted, y = var_3138_promoted)[name = tensor("op_3142")]; tensor var_3143 = greater(x = param_confirm_ev, y = de_ev)[name = tensor("op_3143")]; tensor var_3143_promoted_dtype_0 = const()[name = tensor("op_3143_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3143_promoted = cast(dtype = var_3143_promoted_dtype_0, x = var_3143)[name = tensor("cast_103")]; tensor should_param_confirm = mul(x = var_3142, y = var_3143_promoted)[name = tensor("should_param_confirm")]; tensor var_3148 = greater_equal(x = de_ev, y = dis_ev)[name = tensor("op_3148")]; tensor var_3148_promoted_dtype_0 = const()[name = tensor("op_3148_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3152 = greater_equal(x = de_ev, y = action_confirm_ev)[name = tensor("op_3152")]; tensor var_3152_promoted_dtype_0 = const()[name = tensor("op_3152_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3152_promoted = cast(dtype = var_3152_promoted_dtype_0, x = var_3152)[name = tensor("cast_101")]; tensor var_3148_promoted = cast(dtype = var_3148_promoted_dtype_0, x = var_3148)[name = tensor("cast_102")]; tensor var_3156 = mul(x = var_3148_promoted, y = var_3152_promoted)[name = tensor("op_3156")]; tensor var_3157 = greater_equal(x = de_ev, y = param_confirm_ev)[name = tensor("op_3157")]; tensor var_3157_promoted_dtype_0 = const()[name = tensor("op_3157_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3157_promoted = cast(dtype = var_3157_promoted_dtype_0, x = var_3157)[name = tensor("cast_100")]; tensor should_de = mul(x = var_3156, y = var_3157_promoted)[name = tensor("should_de")]; tensor var_3162 = mul(x = should_de, y = scatter_6)[name = tensor("op_3162")]; tensor var_3163 = mul(x = should_action_confirm, y = scatter_7)[name = tensor("op_3163")]; tensor var_3165 = add(x = var_3162, y = var_3163)[name = tensor("op_3165")]; tensor param_confirm_id_internal_tensor_assign_1_promoted_dtype_0 = const()[name = tensor("param_confirm_id_internal_tensor_assign_1_promoted_dtype_0"), val = tensor("fp32")]; tensor param_confirm_id_internal_tensor_assign_1_promoted = cast(dtype = param_confirm_id_internal_tensor_assign_1_promoted_dtype_0, x = scatter_8)[name = tensor("cast_99")]; tensor var_3166 = mul(x = should_param_confirm, y = param_confirm_id_internal_tensor_assign_1_promoted)[name = tensor("op_3166")]; tensor var_3168 = add(x = var_3165, y = var_3166)[name = tensor("op_3168")]; tensor dis_id_internal_tensor_assign_1_promoted_dtype_0 = const()[name = tensor("dis_id_internal_tensor_assign_1_promoted_dtype_0"), val = tensor("fp32")]; tensor dis_id_internal_tensor_assign_1_promoted = cast(dtype = dis_id_internal_tensor_assign_1_promoted_dtype_0, x = scatter_10)[name = tensor("cast_98")]; tensor var_3169 = mul(x = should_disambig, y = dis_id_internal_tensor_assign_1_promoted)[name = tensor("op_3169")]; tensor actionId = add(x = var_3168, y = var_3169)[name = tensor("op_3171")]; tensor var_3172 = mul(x = should_de, y = de)[name = tensor("op_3172")]; tensor var_3173 = mul(x = should_action_confirm, y = ac)[name = tensor("op_3173")]; tensor var_3175 = add(x = var_3172, y = var_3173)[name = tensor("op_3175")]; tensor var_3176 = mul(x = should_param_confirm, y = param_confirm)[name = tensor("op_3176")]; tensor var_3178 = add(x = var_3175, y = var_3176)[name = tensor("op_3178")]; tensor var_3179 = mul(x = should_disambig, y = dis)[name = tensor("op_3179")]; tensor action_mask = add(x = var_3178, y = var_3179)[name = tensor("action_mask")]; tensor var_3182 = greater(x = shadow_ev, y = ac_ev)[name = tensor("op_3182")]; tensor var_3182_promoted_dtype_0 = const()[name = tensor("op_3182_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3186 = greater(x = shadow_ev, y = param_confirm_ev)[name = tensor("op_3186")]; tensor var_3186_promoted_dtype_0 = const()[name = tensor("op_3186_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3186_promoted = cast(dtype = var_3186_promoted_dtype_0, x = var_3186)[name = tensor("cast_96")]; tensor var_3182_promoted = cast(dtype = var_3182_promoted_dtype_0, x = var_3182)[name = tensor("cast_97")]; tensor var_3190 = mul(x = var_3182_promoted, y = var_3186_promoted)[name = tensor("op_3190")]; tensor var_3191 = greater(x = shadow_ev, y = de_ev)[name = tensor("op_3191")]; tensor var_3191_promoted_dtype_0 = const()[name = tensor("op_3191_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3191_promoted = cast(dtype = var_3191_promoted_dtype_0, x = var_3191)[name = tensor("cast_95")]; tensor shad_should_disambig = mul(x = var_3190, y = var_3191_promoted)[name = tensor("shad_should_disambig")]; tensor var_3196 = greater_equal(x = ac_ev, y = shadow_ev)[name = tensor("op_3196")]; tensor var_3196_promoted_dtype_0 = const()[name = tensor("op_3196_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3200 = greater(x = ac_ev, y = param_confirm_ev)[name = tensor("op_3200")]; tensor var_3200_promoted_dtype_0 = const()[name = tensor("op_3200_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3200_promoted = cast(dtype = var_3200_promoted_dtype_0, x = var_3200)[name = tensor("cast_93")]; tensor var_3196_promoted = cast(dtype = var_3196_promoted_dtype_0, x = var_3196)[name = tensor("cast_94")]; tensor var_3204 = mul(x = var_3196_promoted, y = var_3200_promoted)[name = tensor("op_3204")]; tensor var_3205 = greater(x = ac_ev, y = de_ev)[name = tensor("op_3205")]; tensor var_3205_promoted_dtype_0 = const()[name = tensor("op_3205_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3205_promoted = cast(dtype = var_3205_promoted_dtype_0, x = var_3205)[name = tensor("cast_92")]; tensor shad_should_action_confirm = mul(x = var_3204, y = var_3205_promoted)[name = tensor("shad_should_action_confirm")]; tensor var_3210 = greater_equal(x = param_confirm_ev, y = shadow_ev)[name = tensor("op_3210")]; tensor var_3210_promoted_dtype_0 = const()[name = tensor("op_3210_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3214 = greater(x = param_confirm_ev, y = ac_ev)[name = tensor("op_3214")]; tensor var_3214_promoted_dtype_0 = const()[name = tensor("op_3214_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3214_promoted = cast(dtype = var_3214_promoted_dtype_0, x = var_3214)[name = tensor("cast_90")]; tensor var_3210_promoted = cast(dtype = var_3210_promoted_dtype_0, x = var_3210)[name = tensor("cast_91")]; tensor var_3218 = mul(x = var_3210_promoted, y = var_3214_promoted)[name = tensor("op_3218")]; tensor shad_should_param_confirm = mul(x = var_3218, y = var_3143_promoted)[name = tensor("shad_should_param_confirm")]; tensor var_3224 = greater_equal(x = de_ev, y = shadow_ev)[name = tensor("op_3224")]; tensor var_3224_promoted_dtype_0 = const()[name = tensor("op_3224_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3228 = greater_equal(x = de_ev, y = ac_ev)[name = tensor("op_3228")]; tensor var_3228_promoted_dtype_0 = const()[name = tensor("op_3228_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3228_promoted = cast(dtype = var_3228_promoted_dtype_0, x = var_3228)[name = tensor("cast_88")]; tensor var_3224_promoted = cast(dtype = var_3224_promoted_dtype_0, x = var_3224)[name = tensor("cast_89")]; tensor var_3232 = mul(x = var_3224_promoted, y = var_3228_promoted)[name = tensor("op_3232")]; tensor shad_should_de = mul(x = var_3232, y = var_3157_promoted)[name = tensor("shad_should_de")]; tensor var_3238 = mul(x = shad_should_de, y = scatter_6)[name = tensor("op_3238")]; tensor var_3239 = mul(x = shad_should_action_confirm, y = scatter_7)[name = tensor("op_3239")]; tensor var_3241 = add(x = var_3238, y = var_3239)[name = tensor("op_3241")]; tensor var_3242 = mul(x = shad_should_param_confirm, y = param_confirm_id_internal_tensor_assign_1_promoted)[name = tensor("op_3242")]; tensor var_3244 = add(x = var_3241, y = var_3242)[name = tensor("op_3244")]; tensor var_3245 = mul(x = shad_should_disambig, y = scatter_11)[name = tensor("op_3245")]; tensor shadowActionId = add(x = var_3244, y = var_3245)[name = tensor("op_3247")]; tensor var_3248 = mul(x = shad_should_de, y = de)[name = tensor("op_3248")]; tensor var_3249 = mul(x = shad_should_action_confirm, y = ac)[name = tensor("op_3249")]; tensor var_3251 = add(x = var_3248, y = var_3249)[name = tensor("op_3251")]; tensor var_3252 = mul(x = shad_should_param_confirm, y = param_confirm)[name = tensor("op_3252")]; tensor var_3254 = add(x = var_3251, y = var_3252)[name = tensor("op_3254")]; tensor var_3255 = mul(x = shad_should_disambig, y = shadow_dis)[name = tensor("op_3255")]; tensor shad_action_mask = add(x = var_3254, y = var_3255)[name = tensor("shad_action_mask")]; tensor var_3258_promoted = const()[name = tensor("op_3258_promoted"), val = tensor(0x0p+0)]; tensor var_3259 = greater(x = action_mask, y = var_3258_promoted)[name = tensor("op_3259")]; tensor var_3259_promoted_dtype_0 = const()[name = tensor("op_3259_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3259_promoted = cast(dtype = var_3259_promoted_dtype_0, x = var_3259)[name = tensor("cast_87")]; tensor x_11 = mul(x = x_3, y = var_3259_promoted)[name = tensor("x_11")]; tensor var_3266_begin_0 = const()[name = tensor("op_3266_begin_0"), val = tensor([0, 0])]; tensor var_3266_end_0 = const()[name = tensor("op_3266_end_0"), val = tensor([1, 15])]; tensor var_3266_end_mask_0 = const()[name = tensor("op_3266_end_mask_0"), val = tensor([false, true])]; tensor var_3266_squeeze_mask_0 = const()[name = tensor("op_3266_squeeze_mask_0"), val = tensor([true, false])]; tensor var_3266 = slice_by_index(begin = var_3266_begin_0, end = var_3266_end_0, end_mask = var_3266_end_mask_0, squeeze_mask = var_3266_squeeze_mask_0, x = x_11)[name = tensor("op_3266")]; tensor var_3267 = const()[name = tensor("op_3267"), val = tensor(-0x1.e848p+19)]; tensor var_3268 = greater(x = var_3266, y = var_3267)[name = tensor("op_3268")]; tensor var_3268_promoted_dtype_0 = const()[name = tensor("op_3268_promoted_dtype_0"), val = tensor("int32")]; tensor var_3271 = const()[name = tensor("op_3271"), val = tensor(-0x1.e848p+19)]; tensor ones_9_promoted_dtype_0 = const()[name = tensor("ones_9_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3268_to_fp32 = cast(dtype = ones_9_promoted_dtype_0, x = var_3268)[name = tensor("cast_85")]; tensor small_7 = mul(x = var_3268_to_fp32, y = var_3271)[name = tensor("small_7")]; tensor var_3273 = const()[name = tensor("op_3273"), val = tensor(0x1.e848p+19)]; tensor big_7 = mul(x = var_3268_to_fp32, y = var_3273)[name = tensor("big_7")]; tensor var_3275 = const()[name = tensor("op_3275"), val = tensor(0)]; tensor var_3268_promoted = cast(dtype = var_3268_promoted_dtype_0, x = var_3268)[name = tensor("cast_86")]; tensor zeros_7 = mul(x = var_3268_promoted, y = var_3275)[name = tensor("zeros_7")]; tensor var_3278_axes_0 = const()[name = tensor("op_3278_axes_0"), val = tensor([0])]; tensor var_3278 = expand_dims(axes = var_3278_axes_0, x = small_7)[name = tensor("op_3278")]; tensor var_3280_axes_0 = const()[name = tensor("op_3280_axes_0"), val = tensor([0])]; tensor var_3280 = expand_dims(axes = var_3280_axes_0, x = big_7)[name = tensor("op_3280")]; tensor var_3282 = const()[name = tensor("op_3282"), val = tensor(0)]; tensor x_padded_7_interleave_0 = const()[name = tensor("x_padded_7_interleave_0"), val = tensor(false)]; tensor x_padded_7 = concat(axis = var_3282, interleave = x_padded_7_interleave_0, values = (var_3278, x_11, var_3280))[name = tensor("x_padded_7")]; tensor var_3284 = const()[name = tensor("op_3284"), val = tensor(0)]; tensor logical_not_9 = const()[name = tensor("logical_not_9"), val = tensor(true)]; tensor i_9 = argsort(ascending = logical_not_9, axis = var_3284, x = x_padded_7)[name = tensor("i_9")]; tensor by_x_7 = gather_along_axis(axis = var_3284, indices = i_9, x = x_padded_7)[name = tensor("by_x_7")]; tensor var_3292_begin_0 = const()[name = tensor("op_3292_begin_0"), val = tensor([1, 0])]; tensor var_3292_end_0 = const()[name = tensor("op_3292_end_0"), val = tensor([51, 15])]; tensor var_3292_end_mask_0 = const()[name = tensor("op_3292_end_mask_0"), val = tensor([false, true])]; tensor var_3292 = slice_by_index(begin = var_3292_begin_0, end = var_3292_end_0, end_mask = var_3292_end_mask_0, x = by_x_7)[name = tensor("op_3292")]; tensor var_3297_begin_0 = const()[name = tensor("op_3297_begin_0"), val = tensor([0, 0])]; tensor var_3297_end_0 = const()[name = tensor("op_3297_end_0"), val = tensor([50, 15])]; tensor var_3297_end_mask_0 = const()[name = tensor("op_3297_end_mask_0"), val = tensor([false, true])]; tensor var_3297 = slice_by_index(begin = var_3297_begin_0, end = var_3297_end_0, end_mask = var_3297_end_mask_0, x = by_x_7)[name = tensor("op_3297")]; tensor var_3299 = sub(x = var_3292, y = var_3297)[name = tensor("op_3299")]; tensor var_3300_promoted = const()[name = tensor("op_3300_promoted"), val = tensor(0x0p+0)]; tensor var_3301 = greater(x = var_3299, y = var_3300_promoted)[name = tensor("op_3301")]; tensor var_3301_promoted_dtype_0 = const()[name = tensor("op_3301_promoted_dtype_0"), val = tensor("int32")]; tensor var_3305_axes_0 = const()[name = tensor("op_3305_axes_0"), val = tensor([0])]; tensor var_3305 = expand_dims(axes = var_3305_axes_0, x = zeros_7)[name = tensor("op_3305")]; tensor var_3309 = const()[name = tensor("op_3309"), val = tensor(0)]; tensor mask_21_interleave_0 = const()[name = tensor("mask_21_interleave_0"), val = tensor(false)]; tensor var_3301_promoted = cast(dtype = var_3301_promoted_dtype_0, x = var_3301)[name = tensor("cast_84")]; tensor mask_21 = concat(axis = var_3309, interleave = mask_21_interleave_0, values = (var_3305, var_3301_promoted, var_3305))[name = tensor("mask_21")]; tensor mask_21_promoted_dtype_0 = const()[name = tensor("mask_21_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_21_promoted = cast(dtype = mask_21_promoted_dtype_0, x = mask_21)[name = tensor("cast_83")]; tensor var_3311 = mul(x = by_x_7, y = mask_21_promoted)[name = tensor("op_3311")]; tensor var_3312 = const()[name = tensor("op_3312"), val = tensor(0)]; tensor logical_not_10 = const()[name = tensor("logical_not_10"), val = tensor(true)]; tensor var_3314 = argsort(ascending = logical_not_10, axis = var_3312, x = i_9)[name = tensor("op_3314")]; tensor var_3315 = const()[name = tensor("op_3315"), val = tensor(0)]; tensor unique_7 = gather_along_axis(axis = var_3315, indices = var_3314, x = var_3311)[name = tensor("unique_7")]; tensor relevant_candidates_1_begin_0 = const()[name = tensor("relevant_candidates_1_begin_0"), val = tensor([1, 0])]; tensor relevant_candidates_1_end_0 = const()[name = tensor("relevant_candidates_1_end_0"), val = tensor([51, 15])]; tensor relevant_candidates_1_end_mask_0 = const()[name = tensor("relevant_candidates_1_end_mask_0"), val = tensor([false, true])]; tensor relevant_candidates_1 = slice_by_index(begin = relevant_candidates_1_begin_0, end = relevant_candidates_1_end_0, end_mask = relevant_candidates_1_end_mask_0, x = unique_7)[name = tensor("relevant_candidates_1")]; tensor concat_26 = const()[name = tensor("concat_26"), val = tensor([750])]; tensor act_candidates_1 = reshape(shape = concat_26, x = relevant_candidates_1)[name = tensor("act_candidates_1")]; tensor var_3331 = mul(x = dirichlet_mu_1, y = var_3259_promoted)[name = tensor("op_3331")]; tensor concat_27 = const()[name = tensor("concat_27"), val = tensor([750])]; tensor var_3334 = reshape(shape = concat_27, x = var_3331)[name = tensor("op_3334")]; tensor var_3335 = const()[name = tensor("op_3335"), val = tensor(0)]; tensor logical_not_11 = const()[name = tensor("logical_not_11"), val = tensor(false)]; tensor indx_1 = argsort(ascending = logical_not_11, axis = var_3335, x = var_3334)[name = tensor("indx_1")]; tensor act_candidates_axis_0 = const()[name = tensor("act_candidates_axis_0"), val = tensor(0)]; tensor act_candidates = gather(axis = act_candidates_axis_0, indices = indx_1, x = act_candidates_1)[name = tensor("act_candidates")]; tensor var_3343_promoted = const()[name = tensor("op_3343_promoted"), val = tensor(0x0p+0)]; tensor var_3344 = greater(x = act_candidates, y = var_3343_promoted)[name = tensor("op_3344")]; tensor non_zero_6 = non_zero(x = var_3344)[name = tensor("non_zero_6")]; tensor gather_nd_0 = gather_nd(indices = non_zero_6, x = act_candidates)[name = tensor("gather_nd_0")]; tensor var_3346_perm_0 = const()[name = tensor("op_3346_perm_0"), val = tensor([0])]; tensor var_3347_promoted = const()[name = tensor("op_3347_promoted"), val = tensor(0x0p+0)]; tensor var_3348 = greater(x = shad_action_mask, y = var_3347_promoted)[name = tensor("op_3348")]; tensor var_3348_promoted_dtype_0 = const()[name = tensor("op_3348_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3348_promoted = cast(dtype = var_3348_promoted_dtype_0, x = var_3348)[name = tensor("cast_82")]; tensor x = mul(x = x_3, y = var_3348_promoted)[name = tensor("x")]; tensor var_3355_begin_0 = const()[name = tensor("op_3355_begin_0"), val = tensor([0, 0])]; tensor var_3355_end_0 = const()[name = tensor("op_3355_end_0"), val = tensor([1, 15])]; tensor var_3355_end_mask_0 = const()[name = tensor("op_3355_end_mask_0"), val = tensor([false, true])]; tensor var_3355_squeeze_mask_0 = const()[name = tensor("op_3355_squeeze_mask_0"), val = tensor([true, false])]; tensor var_3355 = slice_by_index(begin = var_3355_begin_0, end = var_3355_end_0, end_mask = var_3355_end_mask_0, squeeze_mask = var_3355_squeeze_mask_0, x = x)[name = tensor("op_3355")]; tensor var_3356 = const()[name = tensor("op_3356"), val = tensor(-0x1.e848p+19)]; tensor var_3357 = greater(x = var_3355, y = var_3356)[name = tensor("op_3357")]; tensor var_3357_promoted_dtype_0 = const()[name = tensor("op_3357_promoted_dtype_0"), val = tensor("int32")]; tensor var_3360 = const()[name = tensor("op_3360"), val = tensor(-0x1.e848p+19)]; tensor ones_promoted_dtype_0 = const()[name = tensor("ones_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3357_to_fp32 = cast(dtype = ones_promoted_dtype_0, x = var_3357)[name = tensor("cast_80")]; tensor small = mul(x = var_3357_to_fp32, y = var_3360)[name = tensor("small")]; tensor var_3362 = const()[name = tensor("op_3362"), val = tensor(0x1.e848p+19)]; tensor big = mul(x = var_3357_to_fp32, y = var_3362)[name = tensor("big")]; tensor var_3364 = const()[name = tensor("op_3364"), val = tensor(0)]; tensor var_3357_promoted = cast(dtype = var_3357_promoted_dtype_0, x = var_3357)[name = tensor("cast_81")]; tensor zeros = mul(x = var_3357_promoted, y = var_3364)[name = tensor("zeros")]; tensor var_3367_axes_0 = const()[name = tensor("op_3367_axes_0"), val = tensor([0])]; tensor var_3367 = expand_dims(axes = var_3367_axes_0, x = small)[name = tensor("op_3367")]; tensor var_3369_axes_0 = const()[name = tensor("op_3369_axes_0"), val = tensor([0])]; tensor var_3369 = expand_dims(axes = var_3369_axes_0, x = big)[name = tensor("op_3369")]; tensor var_3371 = const()[name = tensor("op_3371"), val = tensor(0)]; tensor x_padded_interleave_0 = const()[name = tensor("x_padded_interleave_0"), val = tensor(false)]; tensor x_padded = concat(axis = var_3371, interleave = x_padded_interleave_0, values = (var_3367, x, var_3369))[name = tensor("x_padded")]; tensor var_3373 = const()[name = tensor("op_3373"), val = tensor(0)]; tensor logical_not_12 = const()[name = tensor("logical_not_12"), val = tensor(true)]; tensor i = argsort(ascending = logical_not_12, axis = var_3373, x = x_padded)[name = tensor("i")]; tensor by_x = gather_along_axis(axis = var_3373, indices = i, x = x_padded)[name = tensor("by_x")]; tensor var_3381_begin_0 = const()[name = tensor("op_3381_begin_0"), val = tensor([1, 0])]; tensor var_3381_end_0 = const()[name = tensor("op_3381_end_0"), val = tensor([51, 15])]; tensor var_3381_end_mask_0 = const()[name = tensor("op_3381_end_mask_0"), val = tensor([false, true])]; tensor var_3381 = slice_by_index(begin = var_3381_begin_0, end = var_3381_end_0, end_mask = var_3381_end_mask_0, x = by_x)[name = tensor("op_3381")]; tensor var_3386_begin_0 = const()[name = tensor("op_3386_begin_0"), val = tensor([0, 0])]; tensor var_3386_end_0 = const()[name = tensor("op_3386_end_0"), val = tensor([50, 15])]; tensor var_3386_end_mask_0 = const()[name = tensor("op_3386_end_mask_0"), val = tensor([false, true])]; tensor var_3386 = slice_by_index(begin = var_3386_begin_0, end = var_3386_end_0, end_mask = var_3386_end_mask_0, x = by_x)[name = tensor("op_3386")]; tensor var_3388 = sub(x = var_3381, y = var_3386)[name = tensor("op_3388")]; tensor var_3389_promoted = const()[name = tensor("op_3389_promoted"), val = tensor(0x0p+0)]; tensor var_3390 = greater(x = var_3388, y = var_3389_promoted)[name = tensor("op_3390")]; tensor var_3390_promoted_dtype_0 = const()[name = tensor("op_3390_promoted_dtype_0"), val = tensor("int32")]; tensor var_3394_axes_0 = const()[name = tensor("op_3394_axes_0"), val = tensor([0])]; tensor var_3394 = expand_dims(axes = var_3394_axes_0, x = zeros)[name = tensor("op_3394")]; tensor var_3398 = const()[name = tensor("op_3398"), val = tensor(0)]; tensor mask_interleave_0 = const()[name = tensor("mask_interleave_0"), val = tensor(false)]; tensor var_3390_promoted = cast(dtype = var_3390_promoted_dtype_0, x = var_3390)[name = tensor("cast_79")]; tensor mask = concat(axis = var_3398, interleave = mask_interleave_0, values = (var_3394, var_3390_promoted, var_3394))[name = tensor("mask")]; tensor mask_promoted_dtype_0 = const()[name = tensor("mask_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_promoted = cast(dtype = mask_promoted_dtype_0, x = mask)[name = tensor("cast_78")]; tensor var_3400 = mul(x = by_x, y = mask_promoted)[name = tensor("op_3400")]; tensor var_3401 = const()[name = tensor("op_3401"), val = tensor(0)]; tensor logical_not_13 = const()[name = tensor("logical_not_13"), val = tensor(true)]; tensor var_3403 = argsort(ascending = logical_not_13, axis = var_3401, x = i)[name = tensor("op_3403")]; tensor var_3404 = const()[name = tensor("op_3404"), val = tensor(0)]; tensor unique = gather_along_axis(axis = var_3404, indices = var_3403, x = var_3400)[name = tensor("unique")]; tensor relevant_candidates_begin_0 = const()[name = tensor("relevant_candidates_begin_0"), val = tensor([1, 0])]; tensor relevant_candidates_end_0 = const()[name = tensor("relevant_candidates_end_0"), val = tensor([51, 15])]; tensor relevant_candidates_end_mask_0 = const()[name = tensor("relevant_candidates_end_mask_0"), val = tensor([false, true])]; tensor relevant_candidates = slice_by_index(begin = relevant_candidates_begin_0, end = relevant_candidates_end_0, end_mask = relevant_candidates_end_mask_0, x = unique)[name = tensor("relevant_candidates")]; tensor concat_29 = const()[name = tensor("concat_29"), val = tensor([750])]; tensor shadow_candidates_1 = reshape(shape = concat_29, x = relevant_candidates)[name = tensor("shadow_candidates_1")]; tensor shadow_candidates_axis_0 = const()[name = tensor("shadow_candidates_axis_0"), val = tensor(0)]; tensor shadow_candidates = gather(axis = shadow_candidates_axis_0, indices = indx_1, x = shadow_candidates_1)[name = tensor("shadow_candidates")]; tensor var_3432_promoted = const()[name = tensor("op_3432_promoted"), val = tensor(0x0p+0)]; tensor var_3433 = greater(x = shadow_candidates, y = var_3432_promoted)[name = tensor("op_3433")]; tensor non_zero_7 = non_zero(x = var_3433)[name = tensor("non_zero_7")]; tensor gather_nd_1 = gather_nd(indices = non_zero_7, x = shadow_candidates)[name = tensor("gather_nd_1")]; tensor var_3435_perm_0 = const()[name = tensor("op_3435_perm_0"), val = tensor([0])]; tensor var_3437_axes_0 = const()[name = tensor("op_3437_axes_0"), val = tensor([0])]; tensor var_3437 = expand_dims(axes = var_3437_axes_0, x = dirichlet_mu_1)[name = tensor("op_3437")]; tensor var_3441 = const()[name = tensor("op_3441"), val = tensor(0)]; tensor var_3442_interleave_0 = const()[name = tensor("op_3442_interleave_0"), val = tensor(false)]; tensor rankings = concat(axis = var_3441, interleave = var_3442_interleave_0, values = (var_3437, var_808))[name = tensor("op_3442")]; tensor var_3444 = const()[name = tensor("op_3444"), val = tensor(0)]; tensor var_3445_interleave_0 = const()[name = tensor("op_3445_interleave_0"), val = tensor(false)]; tensor var_3445 = concat(axis = var_3444, interleave = var_3445_interleave_0, values = (var_1154, var_1960, var_1285, var_1155))[name = tensor("op_3445")]; tensor concat_32 = const()[name = tensor("concat_32"), val = tensor([3000])]; tensor diagnostic = reshape(shape = concat_32, x = var_3445)[name = tensor("op_3448")]; tensor v_1_begin_0 = const()[name = tensor("v_1_begin_0"), val = tensor([1, 0])]; tensor v_1_end_0 = const()[name = tensor("v_1_end_0"), val = tensor([2, 1000])]; tensor v_1_end_mask_0 = const()[name = tensor("v_1_end_mask_0"), val = tensor([false, true])]; tensor v_1_squeeze_mask_0 = const()[name = tensor("v_1_squeeze_mask_0"), val = tensor([true, false])]; tensor v_1 = slice_by_index(begin = v_1_begin_0, end = v_1_end_0, end_mask = v_1_end_mask_0, squeeze_mask = v_1_squeeze_mask_0, x = candidate_interactions_transpose)[name = tensor("v_1")]; tensor expand_dims_12_axes_0 = const()[name = tensor("expand_dims_12_axes_0"), val = tensor([0])]; tensor expand_dims_12 = expand_dims(axes = expand_dims_12_axes_0, x = v_1)[name = tensor("expand_dims_12")]; tensor v_3_begin_0 = const()[name = tensor("v_3_begin_0"), val = tensor([2, 0])]; tensor v_3_end_0 = const()[name = tensor("v_3_end_0"), val = tensor([3, 1000])]; tensor v_3_end_mask_0 = const()[name = tensor("v_3_end_mask_0"), val = tensor([false, true])]; tensor v_3_squeeze_mask_0 = const()[name = tensor("v_3_squeeze_mask_0"), val = tensor([true, false])]; tensor v_3 = slice_by_index(begin = v_3_begin_0, end = v_3_end_0, end_mask = v_3_end_mask_0, squeeze_mask = v_3_squeeze_mask_0, x = candidate_interactions_transpose)[name = tensor("v_3")]; tensor expand_dims_13_axes_0 = const()[name = tensor("expand_dims_13_axes_0"), val = tensor([0])]; tensor expand_dims_13 = expand_dims(axes = expand_dims_13_axes_0, x = v_3)[name = tensor("expand_dims_13")]; tensor expand_dims_14_axes_0 = const()[name = tensor("expand_dims_14_axes_0"), val = tensor([0])]; tensor expand_dims_14 = expand_dims(axes = expand_dims_14_axes_0, x = var_464)[name = tensor("expand_dims_14")]; tensor expand_dims_15_axes_0 = const()[name = tensor("expand_dims_15_axes_0"), val = tensor([0])]; tensor expand_dims_15 = expand_dims(axes = expand_dims_15_axes_0, x = _inversed_v)[name = tensor("expand_dims_15")]; tensor var_3520_promoted = const()[name = tensor("op_3520_promoted"), val = tensor(0x0p+0)]; tensor transpose_15_perm_0 = const()[name = tensor("transpose_15_perm_0"), val = tensor([1, 0])]; tensor transpose_15 = transpose(perm = transpose_15_perm_0, x = expand_dims_13)[name = tensor("transpose_19")]; tensor var_3521 = less(x = transpose_15, y = var_3520_promoted)[name = tensor("op_3521")]; tensor var_3523_promoted_dtype_0 = const()[name = tensor("op_3523_promoted_dtype_0"), val = tensor("fp32")]; tensor reshape_62_shape_0 = const()[name = tensor("reshape_62_shape_0"), val = tensor([1, 1000])]; tensor var_3521_to_fp32 = cast(dtype = var_3523_promoted_dtype_0, x = var_3521)[name = tensor("cast_77")]; tensor reshape_62 = reshape(shape = reshape_62_shape_0, x = var_3521_to_fp32)[name = tensor("reshape_62")]; tensor candidate_ts = add(x = expand_dims_13, y = reshape_62)[name = tensor("candidate_ts")]; tensor m_interleave_0 = const()[name = tensor("m_interleave_0"), val = tensor(false)]; tensor const_9 = const()[name = tensor("const_9"), val = tensor(0)]; tensor m = concat(axis = const_9, interleave = m_interleave_0, values = (candidate_ts, expand_dims_14, expand_dims_15, expand_dims_12))[name = tensor("m")]; tensor var_3534_begin_0 = const()[name = tensor("op_3534_begin_0"), val = tensor([0, 0])]; tensor var_3534_end_0 = const()[name = tensor("op_3534_end_0"), val = tensor([1, 1000])]; tensor var_3534_end_mask_0 = const()[name = tensor("op_3534_end_mask_0"), val = tensor([false, true])]; tensor var_3534_squeeze_mask_0 = const()[name = tensor("op_3534_squeeze_mask_0"), val = tensor([true, false])]; tensor var_3534 = slice_by_index(begin = var_3534_begin_0, end = var_3534_end_0, end_mask = var_3534_end_mask_0, squeeze_mask = var_3534_squeeze_mask_0, x = m)[name = tensor("op_3534")]; tensor var_3535 = const()[name = tensor("op_3535"), val = tensor(-1)]; tensor logical_not_15 = const()[name = tensor("logical_not_15"), val = tensor(true)]; tensor indices = argsort(ascending = logical_not_15, axis = var_3535, x = var_3534)[name = tensor("indices")]; tensor sorted_history_axis_0 = const()[name = tensor("sorted_history_axis_0"), val = tensor(0)]; tensor transpose_14_perm_0 = const()[name = tensor("transpose_14_perm_0"), val = tensor([1, 0])]; tensor transpose_14 = transpose(perm = transpose_14_perm_0, x = m)[name = tensor("transpose_18")]; tensor sorted_history = gather(axis = sorted_history_axis_0, indices = indices, x = transpose_14)[name = tensor("sorted_history")]; tensor var_3546_perm_0 = const()[name = tensor("op_3546_perm_0"), val = tensor([1, 0])]; tensor var_3551_begin_0 = const()[name = tensor("op_3551_begin_0"), val = tensor([1, 0])]; tensor var_3551_end_0 = const()[name = tensor("op_3551_end_0"), val = tensor([4, 1000])]; tensor var_3551_end_mask_0 = const()[name = tensor("op_3551_end_mask_0"), val = tensor([true, true])]; tensor var_3546 = transpose(perm = var_3546_perm_0, x = sorted_history)[name = tensor("transpose_17")]; tensor var_3551 = slice_by_index(begin = var_3551_begin_0, end = var_3551_end_0, end_mask = var_3551_end_mask_0, x = var_3546)[name = tensor("op_3551")]; tensor var_3554_perm_0 = const()[name = tensor("op_3554_perm_0"), val = tensor([1, 0])]; tensor var_3559_begin_0 = const()[name = tensor("op_3559_begin_0"), val = tensor([0, 0])]; tensor var_3559_end_0 = const()[name = tensor("op_3559_end_0"), val = tensor([10, 3])]; tensor var_3559_end_mask_0 = const()[name = tensor("op_3559_end_mask_0"), val = tensor([false, true])]; tensor var_3554 = transpose(perm = var_3554_perm_0, x = var_3551)[name = tensor("transpose_16")]; tensor anonymizedHistory = slice_by_index(begin = var_3559_begin_0, end = var_3559_end_0, end_mask = var_3559_end_mask_0, x = var_3554)[name = tensor("op_3559")]; tensor var_3562_begin_0 = const()[name = tensor("op_3562_begin_0"), val = tensor([0])]; tensor var_3562_end_0 = const()[name = tensor("op_3562_end_0"), val = tensor([1])]; tensor var_3562_end_mask_0 = const()[name = tensor("op_3562_end_mask_0"), val = tensor([false])]; tensor var_3562_squeeze_mask_0 = const()[name = tensor("op_3562_squeeze_mask_0"), val = tensor([true])]; tensor var_3562 = slice_by_index(begin = var_3562_begin_0, end = var_3562_end_0, end_mask = var_3562_end_mask_0, squeeze_mask = var_3562_squeeze_mask_0, x = actionId)[name = tensor("op_3562")]; tensor var_3563_promoted = const()[name = tensor("op_3563_promoted"), val = tensor(0x1p+2)]; tensor var_3564 = equal(x = var_3562, y = var_3563_promoted)[name = tensor("op_3564")]; tensor var_3564_promoted_dtype_0 = const()[name = tensor("op_3564_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3564_promoted = cast(dtype = var_3564_promoted_dtype_0, x = var_3564)[name = tensor("cast_76")]; tensor var_3565 = mul(x = var_121_promoted, y = var_3564_promoted)[name = tensor("op_3565")]; tensor var_3568_promoted = const()[name = tensor("op_3568_promoted"), val = tensor(0x1p+0)]; tensor var_3570 = sub(x = var_3568_promoted, y = var_3565)[name = tensor("op_3570")]; tensor var_3574_promoted = const()[name = tensor("op_3574_promoted"), val = tensor(0x1p+1)]; tensor var_3575 = equal(x = var_3562, y = var_3574_promoted)[name = tensor("op_3575")]; tensor var_3575_promoted_dtype_0 = const()[name = tensor("op_3575_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3575_promoted = cast(dtype = var_3575_promoted_dtype_0, x = var_3575)[name = tensor("cast_75")]; tensor var_3576 = mul(x = forced_1, y = var_3575_promoted)[name = tensor("op_3576")]; tensor var_3579_promoted = const()[name = tensor("op_3579_promoted"), val = tensor(0x1p+0)]; tensor var_3581 = sub(x = var_3579_promoted, y = var_3576)[name = tensor("op_3581")]; tensor var_3582 = mul(x = var_3570, y = var_3581)[name = tensor("op_3582")]; tensor var_3586_promoted = const()[name = tensor("op_3586_promoted"), val = tensor(0x1.8p+1)]; tensor var_3587 = equal(x = var_3562, y = var_3586_promoted)[name = tensor("op_3587")]; tensor var_3587_promoted_dtype_0 = const()[name = tensor("op_3587_promoted_dtype_0"), val = tensor("fp32")]; tensor var_3587_promoted = cast(dtype = var_3587_promoted_dtype_0, x = var_3587)[name = tensor("cast_74")]; tensor var_3588 = mul(x = forced_parameter_confirm, y = var_3587_promoted)[name = tensor("op_3588")]; tensor var_3591_promoted = const()[name = tensor("op_3591_promoted"), val = tensor(0x1p+0)]; tensor var_3593 = sub(x = var_3591_promoted, y = var_3588)[name = tensor("op_3593")]; tensor var_3594 = mul(x = var_3582, y = var_3593)[name = tensor("op_3594")]; tensor var_3595_promoted = const()[name = tensor("op_3595_promoted"), val = tensor(0x1p+0)]; tensor forcedPrompt = sub(x = var_3595_promoted, y = var_3594)[name = tensor("op_3597")]; tensor shadowActionCandidates = transpose(perm = var_3435_perm_0, x = gather_nd_1)[name = tensor("transpose_20")]; tensor actionCandidates = transpose(perm = var_3346_perm_0, x = gather_nd_0)[name = tensor("transpose_21")]; } -> (actionId, actionCandidates, shadowActionId, shadowActionCandidates, rankings, diagnostic, anonymizedHistory, forcedPrompt); }