program(1.0) [buildInfo = dict, tensor>({{"coremlc-component-MIL", "3520.2.1"}, {"coremlc-version", "3520.2.1"}, {"coremltools-component-torch", "2.6.0"}, {"coremltools-source-dialect", "TorchScript"}, {"coremltools-version", "8.2.2"}, {"mldb_token", "mldb-wrxg5d7ao8"}})] { func main(tensor alreadyPrompted, tensor candidateInteractions, tensor candidate_risk, tensor component, tensor deviceContext, tensor forcedPromptRate, tensor isResolved, tensor parameterName, tensor riskLevel, tensor similarityScores, tensor tupleInteractions_alignment, tensor tupleInteractions_candidates, tensor tuples) [FlexibleShapeInformation = tuple, dict, tensor>>, tuple, dict, list, ?>>>>((("DefaultShapes", {{"alreadyPrompted", [1]}, {"candidateInteractions", [2, 8]}, {"candidate_risk", [1, 1]}, {"isResolved", [1]}, {"parameterName", [1]}, {"riskLevel", [1, 2]}, {"similarityScores", [1, 3]}, {"tupleInteractions_alignment", [1, 1]}, {"tupleInteractions_candidates", [1, 1]}, {"tuples", [9, 1, 1]}}), ("RangeDims", {{"alreadyPrompted", [[1, 15]]}, {"candidateInteractions", [[2, 10000], [8, 8]]}, {"candidate_risk", [[1, 10000000], [1, 15]]}, {"isResolved", [[1, 15]]}, {"parameterName", [[1, 15]]}, {"riskLevel", [[1, 10000000], [2, 2]]}, {"similarityScores", [[1, 10], [3, 3]]}, {"tupleInteractions_alignment", [[1, 1000], [1, 1000]]}, {"tupleInteractions_candidates", [[1, 1000], [1, 1000]]}, {"tuples", [[9, 9], [1, 10000000], [1, 15]]}})))] { tensor var_18_perm_0 = const()[name = tensor("op_18_perm_0"), val = tensor([1, 0])]; tensor alignments_begin_0 = const()[name = tensor("alignments_begin_0"), val = tensor([1, 0])]; tensor alignments_end_0 = const()[name = tensor("alignments_end_0"), val = tensor([2, 0])]; tensor alignments_end_mask_0 = const()[name = tensor("alignments_end_mask_0"), val = tensor([false, true])]; tensor alignments_squeeze_mask_0 = const()[name = tensor("alignments_squeeze_mask_0"), val = tensor([true, false])]; tensor var_18 = transpose(perm = var_18_perm_0, x = candidateInteractions)[name = tensor("transpose_12")]; tensor alignments = slice_by_index(begin = alignments_begin_0, end = alignments_end_0, end_mask = alignments_end_mask_0, squeeze_mask = alignments_squeeze_mask_0, x = var_18)[name = tensor("alignments")]; tensor var_23_axes_0 = const()[name = tensor("op_23_axes_0"), val = tensor([1])]; tensor var_23 = expand_dims(axes = var_23_axes_0, x = alignments)[name = tensor("op_23")]; tensor var_24 = const()[name = tensor("op_24"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(64)))]; tensor var_26 = sub(x = var_23, y = var_24)[name = tensor("op_26")]; tensor var_27 = abs(x = var_26)[name = tensor("op_27")]; tensor var_28 = const()[name = tensor("op_28"), val = tensor(0x1.0624dep-10)]; tensor var_29 = less(x = var_27, y = var_28)[name = tensor("op_29")]; tensor var_29_promoted_dtype_0 = const()[name = tensor("op_29_promoted_dtype_0"), val = tensor("fp32")]; tensor transpose_0 = const()[name = tensor("transpose_0"), val = tensor(BLOBFILE(path = tensor("@model_path/weights/weight.bin"), offset = tensor(256)))]; tensor new_alignments_bias_0 = const()[name = tensor("new_alignments_bias_0"), val = tensor([0x0p+0])]; tensor var_29_promoted = cast(dtype = var_29_promoted_dtype_0, x = var_29)[name = tensor("cast_77")]; tensor new_alignments = linear(bias = new_alignments_bias_0, weight = transpose_0, x = var_29_promoted)[name = tensor("new_alignments")]; tensor var_38 = squeeze(x = new_alignments)[name = tensor("op_38")]; tensor concat_0 = const()[name = tensor("concat_0"), val = tensor([1, 0])]; tensor concat_1 = const()[name = tensor("concat_1"), val = tensor([0, 0])]; tensor new_candidate_interactions_internal_tensor_assign_1_stride_0 = const()[name = tensor("new_candidate_interactions_internal_tensor_assign_1_stride_0"), val = tensor([1, 1])]; tensor new_candidate_interactions_internal_tensor_assign_1_begin_mask_0 = const()[name = tensor("new_candidate_interactions_internal_tensor_assign_1_begin_mask_0"), val = tensor([false, true])]; tensor new_candidate_interactions_internal_tensor_assign_1_end_mask_0 = const()[name = tensor("new_candidate_interactions_internal_tensor_assign_1_end_mask_0"), val = tensor([false, true])]; tensor new_candidate_interactions_internal_tensor_assign_1_squeeze_mask_0 = const()[name = tensor("new_candidate_interactions_internal_tensor_assign_1_squeeze_mask_0"), val = tensor([true, false])]; tensor shape_6 = shape(x = var_18)[name = tensor("shape_6")]; tensor reduce_prod_0_keep_dims_0 = const()[name = tensor("reduce_prod_0_keep_dims_0"), val = tensor(false)]; tensor reduce_prod_0 = reduce_prod(keep_dims = reduce_prod_0_keep_dims_0, x = shape_6)[name = tensor("reduce_prod_0")]; tensor range_1d_0_start_0 = const()[name = tensor("range_1d_0_start_0"), val = tensor(0)]; tensor range_1d_0_step_0 = const()[name = tensor("range_1d_0_step_0"), val = tensor(1)]; tensor range_1d_0 = range_1d(end = reduce_prod_0, start = range_1d_0_start_0, step = range_1d_0_step_0)[name = tensor("range_1d_0")]; tensor reshape_0 = reshape(shape = shape_6, x = range_1d_0)[name = tensor("reshape_0")]; tensor slice_by_index_6 = slice_by_index(begin = concat_0, begin_mask = new_candidate_interactions_internal_tensor_assign_1_begin_mask_0, end = concat_1, end_mask = new_candidate_interactions_internal_tensor_assign_1_end_mask_0, squeeze_mask = new_candidate_interactions_internal_tensor_assign_1_squeeze_mask_0, stride = new_candidate_interactions_internal_tensor_assign_1_stride_0, x = reshape_0)[name = tensor("slice_by_index_6")]; tensor reshape_3_shape_0 = const()[name = tensor("reshape_3_shape_0"), val = tensor([-1])]; tensor reshape_3 = reshape(shape = reshape_3_shape_0, x = var_18)[name = tensor("reshape_3")]; tensor scatter_0_mode_0 = const()[name = tensor("scatter_0_mode_0"), val = tensor("update")]; tensor scatter_0_axis_0 = const()[name = tensor("scatter_0_axis_0"), val = tensor(0)]; tensor scatter_0 = scatter(axis = scatter_0_axis_0, data = reshape_3, indices = slice_by_index_6, mode = scatter_0_mode_0, updates = var_38)[name = tensor("scatter_0")]; tensor reshape_4 = reshape(shape = shape_6, x = scatter_0)[name = tensor("reshape_4")]; tensor var_49_begin_0 = const()[name = tensor("op_49_begin_0"), val = tensor([0, 0, 0])]; tensor var_49_end_0 = const()[name = tensor("op_49_end_0"), val = tensor([1, 0, 0])]; tensor var_49_end_mask_0 = const()[name = tensor("op_49_end_mask_0"), val = tensor([false, true, true])]; tensor var_49_squeeze_mask_0 = const()[name = tensor("op_49_squeeze_mask_0"), val = tensor([true, false, false])]; tensor var_49 = slice_by_index(begin = var_49_begin_0, end = var_49_end_0, end_mask = var_49_end_mask_0, squeeze_mask = var_49_squeeze_mask_0, x = tuples)[name = tensor("op_49")]; tensor var_50 = const()[name = tensor("op_50"), val = tensor(-0x1p-1)]; tensor var_51 = greater(x = var_49, y = var_50)[name = tensor("op_51")]; tensor not_padded_mask_promoted_dtype_0 = const()[name = tensor("not_padded_mask_promoted_dtype_0"), val = tensor("fp32")]; tensor var_51_to_fp32 = cast(dtype = not_padded_mask_promoted_dtype_0, x = var_51)[name = tensor("cast_76")]; tensor x = mul(x = var_49, y = var_51_to_fp32)[name = tensor("x")]; tensor var_60_begin_0 = const()[name = tensor("op_60_begin_0"), val = tensor([4, 0])]; tensor var_60_end_0 = const()[name = tensor("op_60_end_0"), val = tensor([5, 0])]; tensor var_60_end_mask_0 = const()[name = tensor("op_60_end_mask_0"), val = tensor([false, true])]; tensor var_60_squeeze_mask_0 = const()[name = tensor("op_60_squeeze_mask_0"), val = tensor([true, false])]; tensor var_60 = slice_by_index(begin = var_60_begin_0, end = var_60_end_0, end_mask = var_60_end_mask_0, squeeze_mask = var_60_squeeze_mask_0, x = reshape_4)[name = tensor("op_60")]; tensor var_61 = const()[name = tensor("op_61"), val = tensor(-0x1p-1)]; tensor var_62 = less(x = var_60, y = var_61)[name = tensor("op_62")]; tensor var_62_promoted_dtype_0 = const()[name = tensor("op_62_promoted_dtype_0"), val = tensor("fp32")]; tensor var_70 = const()[name = tensor("op_70"), val = tensor(0x1.0609bep+3)]; tensor var_71 = sub(x = var_60, y = var_70)[name = tensor("op_71")]; tensor cand_time_1 = exp(x = var_71)[name = tensor("cand_time_1")]; tensor var_73_promoted = const()[name = tensor("op_73_promoted"), val = tensor(0x1p+0)]; tensor var_62_promoted = cast(dtype = var_62_promoted_dtype_0, x = var_62)[name = tensor("cast_75")]; tensor var_75 = sub(x = var_73_promoted, y = var_62_promoted)[name = tensor("op_75")]; tensor var_76 = const()[name = tensor("op_76"), val = tensor(0x1.0c1524p-2)]; tensor var_77 = mul(x = cand_time_1, y = var_76)[name = tensor("op_77")]; tensor var_78 = cos(x = var_77)[name = tensor("op_78")]; tensor var_79_promoted = const()[name = tensor("op_79_promoted"), val = tensor(-0x1p+0)]; tensor var_80 = mul(x = var_78, y = var_79_promoted)[name = tensor("op_80")]; tensor var_81 = const()[name = tensor("op_81"), val = tensor(0x1p-1)]; tensor var_82 = mul(x = var_80, y = var_81)[name = tensor("op_82")]; tensor var_84 = const()[name = tensor("op_84"), val = tensor(0x1p-1)]; tensor var_85 = add(x = var_82, y = var_84)[name = tensor("op_85")]; tensor var_86 = mul(x = var_75, y = var_85)[name = tensor("op_86")]; tensor var_87_promoted = const()[name = tensor("op_87_promoted"), val = tensor(-0x1p+0)]; tensor var_88 = mul(x = var_62_promoted, y = var_87_promoted)[name = tensor("op_88")]; tensor cand_freq_1 = add(x = var_86, y = var_88)[name = tensor("cand_freq_1")]; tensor var_92_axes_0 = const()[name = tensor("op_92_axes_0"), val = tensor([0])]; tensor var_92 = expand_dims(axes = var_92_axes_0, x = cand_freq_1)[name = tensor("op_92")]; tensor var_94 = const()[name = tensor("op_94"), val = tensor(0)]; tensor var_95_interleave_0 = const()[name = tensor("op_95_interleave_0"), val = tensor(false)]; tensor var_95 = concat(axis = var_94, interleave = var_95_interleave_0, values = (reshape_4, var_92))[name = tensor("op_95")]; tensor var_104_begin_0 = const()[name = tensor("op_104_begin_0"), val = tensor([4, 0])]; tensor var_104_end_0 = const()[name = tensor("op_104_end_0"), val = tensor([5, 0])]; tensor var_104_end_mask_0 = const()[name = tensor("op_104_end_mask_0"), val = tensor([false, true])]; tensor var_104_squeeze_mask_0 = const()[name = tensor("op_104_squeeze_mask_0"), val = tensor([true, false])]; tensor var_104 = slice_by_index(begin = var_104_begin_0, end = var_104_end_0, end_mask = var_104_end_mask_0, squeeze_mask = var_104_squeeze_mask_0, x = var_95)[name = tensor("op_104")]; tensor var_105 = const()[name = tensor("op_105"), val = tensor(-0x1p-1)]; tensor var_106 = less(x = var_104, y = var_105)[name = tensor("op_106")]; tensor var_106_promoted_dtype_0 = const()[name = tensor("op_106_promoted_dtype_0"), val = tensor("fp32")]; tensor var_114 = const()[name = tensor("op_114"), val = tensor(0x1.0609bep+3)]; tensor var_115 = sub(x = var_104, y = var_114)[name = tensor("op_115")]; tensor cand_time = exp(x = var_115)[name = tensor("cand_time")]; tensor var_117_promoted = const()[name = tensor("op_117_promoted"), val = tensor(0x1p+0)]; tensor var_106_promoted = cast(dtype = var_106_promoted_dtype_0, x = var_106)[name = tensor("cast_74")]; tensor var_119 = sub(x = var_117_promoted, y = var_106_promoted)[name = tensor("op_119")]; tensor var_120 = const()[name = tensor("op_120"), val = tensor(0x1.32614ep-5)]; tensor var_121 = mul(x = cand_time, y = var_120)[name = tensor("op_121")]; tensor var_122 = cos(x = var_121)[name = tensor("op_122")]; tensor var_123_promoted = const()[name = tensor("op_123_promoted"), val = tensor(-0x1p+0)]; tensor var_124 = mul(x = var_122, y = var_123_promoted)[name = tensor("op_124")]; tensor var_125 = const()[name = tensor("op_125"), val = tensor(0x1p-1)]; tensor var_126 = mul(x = var_124, y = var_125)[name = tensor("op_126")]; tensor var_128 = const()[name = tensor("op_128"), val = tensor(0x1p-1)]; tensor var_129 = add(x = var_126, y = var_128)[name = tensor("op_129")]; tensor var_130 = mul(x = var_119, y = var_129)[name = tensor("op_130")]; tensor var_131_promoted = const()[name = tensor("op_131_promoted"), val = tensor(-0x1p+0)]; tensor var_132 = mul(x = var_106_promoted, y = var_131_promoted)[name = tensor("op_132")]; tensor cand_freq = add(x = var_130, y = var_132)[name = tensor("cand_freq")]; tensor var_136_axes_0 = const()[name = tensor("op_136_axes_0"), val = tensor([0])]; tensor var_136 = expand_dims(axes = var_136_axes_0, x = cand_freq)[name = tensor("op_136")]; tensor var_138 = const()[name = tensor("op_138"), val = tensor(0)]; tensor var_139_interleave_0 = const()[name = tensor("op_139_interleave_0"), val = tensor(false)]; tensor var_139 = concat(axis = var_138, interleave = var_139_interleave_0, values = (var_95, var_136))[name = tensor("op_139")]; tensor candidate_interactions_9_perm_0 = const()[name = tensor("candidate_interactions_9_perm_0"), val = tensor([1, 0])]; tensor var_143 = const()[name = tensor("op_143"), val = tensor([0x0p+0, 0x0p+0])]; tensor var_145 = const()[name = tensor("op_145"), val = tensor(0)]; tensor device_context_interleave_0 = const()[name = tensor("device_context_interleave_0"), val = tensor(false)]; tensor device_context = concat(axis = var_145, interleave = device_context_interleave_0, values = (deviceContext, var_143))[name = tensor("device_context")]; tensor var_149_perm_0 = const()[name = tensor("op_149_perm_0"), val = tensor([1, 0])]; tensor var_155_begin_0 = const()[name = tensor("op_155_begin_0"), val = tensor([4, 0])]; tensor var_155_end_0 = const()[name = tensor("op_155_end_0"), val = tensor([5, 0])]; tensor var_155_end_mask_0 = const()[name = tensor("op_155_end_mask_0"), val = tensor([false, true])]; tensor var_155_squeeze_mask_0 = const()[name = tensor("op_155_squeeze_mask_0"), val = tensor([true, false])]; tensor candidate_interactions_9 = transpose(perm = candidate_interactions_9_perm_0, x = var_139)[name = tensor("transpose_11")]; tensor var_149 = transpose(perm = var_149_perm_0, x = candidate_interactions_9)[name = tensor("transpose_10")]; tensor var_155 = slice_by_index(begin = var_155_begin_0, end = var_155_end_0, end_mask = var_155_end_mask_0, squeeze_mask = var_155_squeeze_mask_0, x = var_149)[name = tensor("op_155")]; tensor var_156_keep_dims_0 = const()[name = tensor("op_156_keep_dims_0"), val = tensor(false)]; tensor var_156 = reduce_max(keep_dims = var_156_keep_dims_0, x = var_155)[name = tensor("op_156")]; tensor var_158_promoted = const()[name = tensor("op_158_promoted"), val = tensor(0x1.cp+2)]; tensor time_correction = sub(x = var_156, y = var_158_promoted)[name = tensor("time_correction")]; tensor var_164 = sub(x = var_155, y = time_correction)[name = tensor("op_164")]; tensor var_165 = exp(x = var_164)[name = tensor("op_165")]; tensor concat_2 = const()[name = tensor("concat_2"), val = tensor([4, 0])]; tensor concat_3 = const()[name = tensor("concat_3"), val = tensor([0, 0])]; tensor candidate_by_column_internal_tensor_assign_1_stride_0 = const()[name = tensor("candidate_by_column_internal_tensor_assign_1_stride_0"), val = tensor([1, 1])]; tensor candidate_by_column_internal_tensor_assign_1_begin_mask_0 = const()[name = tensor("candidate_by_column_internal_tensor_assign_1_begin_mask_0"), val = tensor([false, true])]; tensor candidate_by_column_internal_tensor_assign_1_end_mask_0 = const()[name = tensor("candidate_by_column_internal_tensor_assign_1_end_mask_0"), val = tensor([false, true])]; tensor candidate_by_column_internal_tensor_assign_1_squeeze_mask_0 = const()[name = tensor("candidate_by_column_internal_tensor_assign_1_squeeze_mask_0"), val = tensor([true, false])]; tensor shape_7 = shape(x = var_149)[name = tensor("shape_7")]; tensor reduce_prod_1_keep_dims_0 = const()[name = tensor("reduce_prod_1_keep_dims_0"), val = tensor(false)]; tensor reduce_prod_1 = reduce_prod(keep_dims = reduce_prod_1_keep_dims_0, x = shape_7)[name = tensor("reduce_prod_1")]; tensor range_1d_1_start_0 = const()[name = tensor("range_1d_1_start_0"), val = tensor(0)]; tensor range_1d_1_step_0 = const()[name = tensor("range_1d_1_step_0"), val = tensor(1)]; tensor range_1d_1 = range_1d(end = reduce_prod_1, start = range_1d_1_start_0, step = range_1d_1_step_0)[name = tensor("range_1d_1")]; tensor reshape_5 = reshape(shape = shape_7, x = range_1d_1)[name = tensor("reshape_5")]; tensor slice_by_index_7 = slice_by_index(begin = concat_2, begin_mask = candidate_by_column_internal_tensor_assign_1_begin_mask_0, end = concat_3, end_mask = candidate_by_column_internal_tensor_assign_1_end_mask_0, squeeze_mask = candidate_by_column_internal_tensor_assign_1_squeeze_mask_0, stride = candidate_by_column_internal_tensor_assign_1_stride_0, x = reshape_5)[name = tensor("slice_by_index_7")]; tensor reshape_8_shape_0 = const()[name = tensor("reshape_8_shape_0"), val = tensor([-1])]; tensor reshape_8 = reshape(shape = reshape_8_shape_0, x = var_149)[name = tensor("reshape_8")]; tensor scatter_1_mode_0 = const()[name = tensor("scatter_1_mode_0"), val = tensor("update")]; tensor scatter_1_axis_0 = const()[name = tensor("scatter_1_axis_0"), val = tensor(0)]; tensor scatter_1 = scatter(axis = scatter_1_axis_0, data = reshape_8, indices = slice_by_index_7, mode = scatter_1_mode_0, updates = var_165)[name = tensor("scatter_1")]; tensor reshape_9 = reshape(shape = shape_7, x = scatter_1)[name = tensor("reshape_9")]; tensor candidate_interactions_11_perm_0 = const()[name = tensor("candidate_interactions_11_perm_0"), val = tensor([1, 0])]; tensor var_174 = const()[name = tensor("op_174"), val = tensor(-0x1p-1)]; tensor var_175 = greater(x = candidate_interactions_9, y = var_174)[name = tensor("op_175")]; tensor var_175_promoted_dtype_0 = const()[name = tensor("op_175_promoted_dtype_0"), val = tensor("fp32")]; tensor var_175_promoted = cast(dtype = var_175_promoted_dtype_0, x = var_175)[name = tensor("cast_73")]; tensor candidate_interactions_11 = transpose(perm = candidate_interactions_11_perm_0, x = reshape_9)[name = tensor("transpose_9")]; tensor var_176 = mul(x = var_175_promoted, y = candidate_interactions_11)[name = tensor("op_176")]; tensor var_177 = const()[name = tensor("op_177"), val = tensor(-0x1p-1)]; tensor var_178 = less(x = candidate_interactions_9, y = var_177)[name = tensor("op_178")]; tensor var_179 = const()[name = tensor("op_179"), val = tensor(-1)]; tensor var_178_promoted_dtype_0 = const()[name = tensor("op_178_promoted_dtype_0"), val = tensor("int32")]; tensor var_178_promoted = cast(dtype = var_178_promoted_dtype_0, x = var_178)[name = tensor("cast_72")]; tensor var_180 = mul(x = var_178_promoted, y = var_179)[name = tensor("op_180")]; tensor var_180_promoted_dtype_0 = const()[name = tensor("op_180_promoted_dtype_0"), val = tensor("fp32")]; tensor var_180_promoted = cast(dtype = var_180_promoted_dtype_0, x = var_180)[name = tensor("cast_71")]; tensor candidate_interactions = add(x = var_176, y = var_180_promoted)[name = tensor("candidate_interactions")]; tensor candidate_interactions_transpose_perm_0 = const()[name = tensor("candidate_interactions_transpose_perm_0"), val = tensor([1, 0])]; tensor var_188 = const()[name = tensor("op_188"), val = tensor(0x1p-1)]; tensor candidate_interactions_transpose = transpose(perm = candidate_interactions_transpose_perm_0, x = candidate_interactions)[name = tensor("transpose_8")]; tensor var_189 = equal(x = candidate_interactions_transpose, y = var_188)[name = tensor("op_189")]; tensor cast_0_dtype_0 = const()[name = tensor("cast_0_dtype_0"), val = tensor("fp32")]; tensor mask_1_axes_0 = const()[name = tensor("mask_1_axes_0"), val = tensor([0])]; tensor mask_1_keep_dims_0 = const()[name = tensor("mask_1_keep_dims_0"), val = tensor(false)]; tensor cast_0 = cast(dtype = cast_0_dtype_0, x = var_189)[name = tensor("cast_70")]; tensor mask_1 = reduce_sum(axes = mask_1_axes_0, keep_dims = mask_1_keep_dims_0, x = cast_0)[name = tensor("mask_1")]; tensor var_200_begin_0 = const()[name = tensor("op_200_begin_0"), val = tensor([0, 0])]; tensor var_200_end_0 = const()[name = tensor("op_200_end_0"), val = tensor([1, 0])]; tensor var_200_end_mask_0 = const()[name = tensor("op_200_end_mask_0"), val = tensor([false, true])]; tensor var_200_squeeze_mask_0 = const()[name = tensor("op_200_squeeze_mask_0"), val = tensor([true, false])]; tensor var_200 = slice_by_index(begin = var_200_begin_0, end = var_200_end_0, end_mask = var_200_end_mask_0, squeeze_mask = var_200_squeeze_mask_0, x = candidate_interactions_transpose)[name = tensor("op_200")]; tensor var_201_promoted = const()[name = tensor("op_201_promoted"), val = tensor(0x0p+0)]; tensor var_202 = mul(x = var_200, y = var_201_promoted)[name = tensor("op_202")]; tensor zero_slice_1_axes_0 = const()[name = tensor("zero_slice_1_axes_0"), val = tensor([0])]; tensor zero_slice_1 = expand_dims(axes = zero_slice_1_axes_0, x = var_202)[name = tensor("zero_slice_1")]; tensor var_206 = const()[name = tensor("op_206"), val = tensor(0)]; tensor context_feedback_3_interleave_0 = const()[name = tensor("context_feedback_3_interleave_0"), val = tensor(false)]; tensor context_feedback_3 = concat(axis = var_206, interleave = context_feedback_3_interleave_0, values = (candidate_interactions_transpose, zero_slice_1))[name = tensor("context_feedback_3")]; tensor var_212_begin_0 = const()[name = tensor("op_212_begin_0"), val = tensor([4, 0])]; tensor var_212_end_0 = const()[name = tensor("op_212_end_0"), val = tensor([10, 0])]; tensor var_212_end_mask_0 = const()[name = tensor("op_212_end_mask_0"), val = tensor([false, true])]; tensor var_212 = slice_by_index(begin = var_212_begin_0, end = var_212_end_0, end_mask = var_212_end_mask_0, x = context_feedback_3)[name = tensor("op_212")]; tensor context_feedback_5_perm_0 = const()[name = tensor("context_feedback_5_perm_0"), val = tensor([1, 0])]; tensor context_feedback_5 = transpose(perm = context_feedback_5_perm_0, x = var_212)[name = tensor("transpose_7")]; tensor var_217_shape = shape(x = context_feedback_5)[name = tensor("op_217_shape")]; tensor gather_0_indices_0 = const()[name = tensor("gather_0_indices_0"), val = tensor(0)]; tensor gather_0_axis_0 = const()[name = tensor("gather_0_axis_0"), val = tensor(0)]; tensor gather_0 = gather(axis = gather_0_axis_0, indices = gather_0_indices_0, x = var_217_shape)[name = tensor("gather_0")]; tensor var_223 = const()[name = tensor("op_223"), val = tensor(6)]; tensor concat_4_axis_0 = const()[name = tensor("concat_4_axis_0"), val = tensor(0)]; tensor concat_4_interleave_0 = const()[name = tensor("concat_4_interleave_0"), val = tensor(false)]; tensor concat_4 = concat(axis = concat_4_axis_0, interleave = concat_4_interleave_0, values = (gather_0, var_223))[name = tensor("concat_4")]; tensor fill_0_value_0 = const()[name = tensor("fill_0_value_0"), val = tensor(0x1p+0)]; tensor fill_0 = fill(shape = concat_4, value = fill_0_value_0)[name = tensor("fill_0")]; tensor var_231_axes_0 = const()[name = tensor("op_231_axes_0"), val = tensor([-1])]; tensor var_231 = expand_dims(axes = var_231_axes_0, x = mask_1)[name = tensor("op_231")]; tensor var_232 = mul(x = fill_0, y = var_231)[name = tensor("op_232")]; tensor var_233_promoted = const()[name = tensor("op_233_promoted"), val = tensor(-0x1p+0)]; tensor var_234 = mul(x = var_232, y = var_233_promoted)[name = tensor("op_234")]; tensor var_235_promoted = const()[name = tensor("op_235_promoted"), val = tensor(0x1p+0)]; tensor var_237 = sub(x = var_235_promoted, y = mask_1)[name = tensor("op_237")]; tensor var_239_axes_0 = const()[name = tensor("op_239_axes_0"), val = tensor([-1])]; tensor var_239 = expand_dims(axes = var_239_axes_0, x = var_237)[name = tensor("op_239")]; tensor var_240 = mul(x = context_feedback_5, y = var_239)[name = tensor("op_240")]; tensor context_feedback_7 = add(x = var_234, y = var_240)[name = tensor("context_feedback_7")]; tensor var_244_axes_0 = const()[name = tensor("op_244_axes_0"), val = tensor([0])]; tensor var_244 = expand_dims(axes = var_244_axes_0, x = device_context)[name = tensor("op_244")]; tensor var_246 = const()[name = tensor("op_246"), val = tensor(0)]; tensor context_feedback_9_interleave_0 = const()[name = tensor("context_feedback_9_interleave_0"), val = tensor(false)]; tensor context_feedback_9 = concat(axis = var_246, interleave = context_feedback_9_interleave_0, values = (context_feedback_7, var_244))[name = tensor("context_feedback_9")]; tensor var_248 = const()[name = tensor("op_248"), val = tensor(-0x1.8p+0)]; tensor var_249 = greater(x = context_feedback_9, y = var_248)[name = tensor("op_249")]; tensor var_249_promoted_dtype_0 = const()[name = tensor("op_249_promoted_dtype_0"), val = tensor("fp32")]; tensor var_253 = const()[name = tensor("op_253"), val = tensor(-0x1p-1)]; tensor var_254 = less(x = context_feedback_9, y = var_253)[name = tensor("op_254")]; tensor var_254_promoted_dtype_0 = const()[name = tensor("op_254_promoted_dtype_0"), val = tensor("fp32")]; tensor var_254_promoted = cast(dtype = var_254_promoted_dtype_0, x = var_254)[name = tensor("cast_68")]; tensor var_249_promoted = cast(dtype = var_249_promoted_dtype_0, x = var_249)[name = tensor("cast_69")]; tensor var_258 = mul(x = var_249_promoted, y = var_254_promoted)[name = tensor("op_258")]; tensor var_259_promoted = const()[name = tensor("op_259_promoted"), val = tensor(0x1p+0)]; tensor padded_context_mask = sub(x = var_259_promoted, y = var_258)[name = tensor("padded_context_mask")]; tensor masked_context_1 = mul(x = padded_context_mask, y = context_feedback_9)[name = tensor("masked_context_1")]; tensor var_263 = abs(x = masked_context_1)[name = tensor("op_263")]; tensor scaled_keep_dims_0 = const()[name = tensor("scaled_keep_dims_0"), val = tensor(false)]; tensor scaled = reduce_max(keep_dims = scaled_keep_dims_0, x = var_263)[name = tensor("scaled")]; tensor var_266 = const()[name = tensor("op_266"), val = tensor(0x1.0624dep-10)]; tensor var_267 = add(x = scaled, y = var_266)[name = tensor("op_267")]; tensor var_268 = real_div(x = masked_context_1, y = var_267)[name = tensor("op_268")]; tensor var_269_promoted = const()[name = tensor("op_269_promoted"), val = tensor(0x1.f4p+9)]; tensor masked_context = mul(x = var_268, y = var_269_promoted)[name = tensor("masked_context")]; tensor modified_actual_sum_axes_0 = const()[name = tensor("modified_actual_sum_axes_0"), val = tensor([0])]; tensor modified_actual_sum_keep_dims_0 = const()[name = tensor("modified_actual_sum_keep_dims_0"), val = tensor(false)]; tensor modified_actual_sum = reduce_sum(axes = modified_actual_sum_axes_0, keep_dims = modified_actual_sum_keep_dims_0, x = masked_context)[name = tensor("modified_actual_sum")]; tensor modified_actual_count_axes_0 = const()[name = tensor("modified_actual_count_axes_0"), val = tensor([0])]; tensor modified_actual_count_keep_dims_0 = const()[name = tensor("modified_actual_count_keep_dims_0"), val = tensor(false)]; tensor modified_actual_count = reduce_sum(axes = modified_actual_count_axes_0, keep_dims = modified_actual_count_keep_dims_0, x = padded_context_mask)[name = tensor("modified_actual_count")]; tensor var_284 = const()[name = tensor("op_284"), val = tensor(0x1.0624dep-10)]; tensor var_285 = add(x = modified_actual_count, y = var_284)[name = tensor("op_285")]; tensor adjusted_mean_1 = real_div(x = modified_actual_sum, y = var_285)[name = tensor("adjusted_mean_1")]; tensor _inversed_288_y_0 = const()[name = tensor("_inversed_288_y_0"), val = tensor(0x1.0624dep-10)]; tensor _inversed_288 = mul(x = adjusted_mean_1, y = _inversed_288_y_0)[name = tensor("_inversed_288")]; tensor var_289 = mul(x = _inversed_288, y = scaled)[name = tensor("op_289")]; tensor not_equal_0 = not_equal(x = var_289, y = var_289)[name = tensor("not_equal_0")]; tensor cast_7_dtype_0 = const()[name = tensor("cast_7_dtype_0"), val = tensor("int32")]; tensor cast_7 = cast(dtype = cast_7_dtype_0, x = not_equal_0)[name = tensor("cast_67")]; tensor non_zero_0 = non_zero(x = cast_7)[name = tensor("non_zero_0")]; tensor expand_dims_0 = const()[name = tensor("expand_dims_0"), val = tensor([0x0p+0])]; tensor shape_0 = shape(x = non_zero_0)[name = tensor("shape_0")]; tensor slice_by_index_0_begin_0 = const()[name = tensor("slice_by_index_0_begin_0"), val = tensor([0])]; tensor slice_by_index_0_end_0 = const()[name = tensor("slice_by_index_0_end_0"), val = tensor([0])]; tensor slice_by_index_0_squeeze_mask_0 = const()[name = tensor("slice_by_index_0_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_0 = slice_by_index(begin = slice_by_index_0_begin_0, end = slice_by_index_0_end_0, squeeze_mask = slice_by_index_0_squeeze_mask_0, x = shape_0)[name = tensor("slice_by_index_0")]; tensor expand_dims_1_axes_0 = const()[name = tensor("expand_dims_1_axes_0"), val = tensor([0])]; tensor expand_dims_1 = expand_dims(axes = expand_dims_1_axes_0, x = slice_by_index_0)[name = tensor("expand_dims_1")]; tensor tile_0 = tile(reps = expand_dims_1, x = expand_dims_0)[name = tensor("tile_0")]; tensor scatter_nd_0_mode_0 = const()[name = tensor("scatter_nd_0_mode_0"), val = tensor("update")]; tensor scatter_nd_0 = scatter_nd(data = var_289, indices = non_zero_0, mode = scatter_nd_0_mode_0, updates = tile_0)[name = tensor("scatter_nd_0")]; tensor mul_0_y_0 = const()[name = tensor("mul_0_y_0"), val = tensor(0x0p+0)]; tensor mul_0 = mul(x = var_289, y = mul_0_y_0)[name = tensor("mul_0")]; tensor not_equal_1 = not_equal(x = mul_0, y = mul_0)[name = tensor("not_equal_1")]; tensor greater_0_y_0 = const()[name = tensor("greater_0_y_0"), val = tensor(0x0p+0)]; tensor greater_0 = greater(x = var_289, y = greater_0_y_0)[name = tensor("greater_0")]; tensor logical_and_0 = logical_and(x = not_equal_1, y = greater_0)[name = tensor("logical_and_0")]; tensor less_0_y_0 = const()[name = tensor("less_0_y_0"), val = tensor(0x0p+0)]; tensor less_0 = less(x = var_289, y = less_0_y_0)[name = tensor("less_0")]; tensor logical_and_1 = logical_and(x = not_equal_1, y = less_0)[name = tensor("logical_and_1")]; tensor cast_8_dtype_0 = const()[name = tensor("cast_8_dtype_0"), val = tensor("int32")]; tensor cast_8 = cast(dtype = cast_8_dtype_0, x = logical_and_0)[name = tensor("cast_66")]; tensor non_zero_1 = non_zero(x = cast_8)[name = tensor("non_zero_1")]; tensor expand_dims_2 = const()[name = tensor("expand_dims_2"), val = tensor([0x1.fffffep+127])]; tensor shape_1 = shape(x = non_zero_1)[name = tensor("shape_1")]; tensor slice_by_index_1_begin_0 = const()[name = tensor("slice_by_index_1_begin_0"), val = tensor([0])]; tensor slice_by_index_1_end_0 = const()[name = tensor("slice_by_index_1_end_0"), val = tensor([0])]; tensor slice_by_index_1_squeeze_mask_0 = const()[name = tensor("slice_by_index_1_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_1 = slice_by_index(begin = slice_by_index_1_begin_0, end = slice_by_index_1_end_0, squeeze_mask = slice_by_index_1_squeeze_mask_0, x = shape_1)[name = tensor("slice_by_index_1")]; tensor expand_dims_3_axes_0 = const()[name = tensor("expand_dims_3_axes_0"), val = tensor([0])]; tensor expand_dims_3 = expand_dims(axes = expand_dims_3_axes_0, x = slice_by_index_1)[name = tensor("expand_dims_3")]; tensor tile_1 = tile(reps = expand_dims_3, x = expand_dims_2)[name = tensor("tile_1")]; tensor scatter_nd_1_mode_0 = const()[name = tensor("scatter_nd_1_mode_0"), val = tensor("update")]; tensor scatter_nd_1 = scatter_nd(data = scatter_nd_0, indices = non_zero_1, mode = scatter_nd_1_mode_0, updates = tile_1)[name = tensor("scatter_nd_1")]; tensor cast_9_dtype_0 = const()[name = tensor("cast_9_dtype_0"), val = tensor("int32")]; tensor cast_9 = cast(dtype = cast_9_dtype_0, x = logical_and_1)[name = tensor("cast_65")]; tensor non_zero_2 = non_zero(x = cast_9)[name = tensor("non_zero_2")]; tensor expand_dims_4 = const()[name = tensor("expand_dims_4"), val = tensor([-0x1.fffffep+127])]; tensor shape_2 = shape(x = non_zero_2)[name = tensor("shape_2")]; tensor slice_by_index_2_begin_0 = const()[name = tensor("slice_by_index_2_begin_0"), val = tensor([0])]; tensor slice_by_index_2_end_0 = const()[name = tensor("slice_by_index_2_end_0"), val = tensor([0])]; tensor slice_by_index_2_squeeze_mask_0 = const()[name = tensor("slice_by_index_2_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_2 = slice_by_index(begin = slice_by_index_2_begin_0, end = slice_by_index_2_end_0, squeeze_mask = slice_by_index_2_squeeze_mask_0, x = shape_2)[name = tensor("slice_by_index_2")]; tensor expand_dims_5_axes_0 = const()[name = tensor("expand_dims_5_axes_0"), val = tensor([0])]; tensor expand_dims_5 = expand_dims(axes = expand_dims_5_axes_0, x = slice_by_index_2)[name = tensor("expand_dims_5")]; tensor tile_2 = tile(reps = expand_dims_5, x = expand_dims_4)[name = tensor("tile_2")]; tensor scatter_nd_2_mode_0 = const()[name = tensor("scatter_nd_2_mode_0"), val = tensor("update")]; tensor scatter_nd_2 = scatter_nd(data = scatter_nd_1, indices = non_zero_2, mode = scatter_nd_2_mode_0, updates = tile_2)[name = tensor("scatter_nd_2")]; tensor var_294_promoted = const()[name = tensor("op_294_promoted"), val = tensor(0x1p+0)]; tensor var_296 = sub(x = var_294_promoted, y = padded_context_mask)[name = tensor("op_296")]; tensor var_297 = const()[name = tensor("op_297"), val = tensor(0x1.0624dep-10)]; tensor var_298 = mul(x = var_296, y = var_297)[name = tensor("op_298")]; tensor var_300 = add(x = padded_context_mask, y = var_298)[name = tensor("op_300")]; tensor log_padded_epsilon_0 = const()[name = tensor("log_padded_epsilon_0"), val = tensor(0x1p-149)]; tensor log_padded = log(epsilon = log_padded_epsilon_0, x = var_300)[name = tensor("log_padded")]; tensor var_303 = sub(x = context_feedback_9, y = scatter_nd_2)[name = tensor("op_303")]; tensor var_304 = abs(x = var_303)[name = tensor("op_304")]; tensor var_306 = const()[name = tensor("op_306"), val = tensor(0x1.0624dep-10)]; tensor var_307 = add(x = var_304, y = var_306)[name = tensor("op_307")]; tensor var_308_epsilon_0 = const()[name = tensor("op_308_epsilon_0"), val = tensor(0x1p-149)]; tensor var_308 = log(epsilon = var_308_epsilon_0, x = var_307)[name = tensor("op_308")]; tensor var_309_promoted = const()[name = tensor("op_309_promoted"), val = tensor(0x1p+1)]; tensor var_310 = mul(x = var_308, y = var_309_promoted)[name = tensor("op_310")]; tensor x_1 = add(x = var_310, y = log_padded)[name = tensor("x_1")]; tensor reduce_max_0_axes_0 = const()[name = tensor("reduce_max_0_axes_0"), val = tensor([0])]; tensor reduce_max_0_keep_dims_0 = const()[name = tensor("reduce_max_0_keep_dims_0"), val = tensor(false)]; tensor reduce_max_0 = reduce_max(axes = reduce_max_0_axes_0, keep_dims = reduce_max_0_keep_dims_0, x = x_1)[name = tensor("reduce_max_0")]; tensor var_318_axes_0 = const()[name = tensor("op_318_axes_0"), val = tensor([0])]; tensor var_318 = expand_dims(axes = var_318_axes_0, x = reduce_max_0)[name = tensor("op_318")]; tensor var_320 = sub(x = x_1, y = var_318)[name = tensor("op_320")]; tensor var_321 = exp(x = var_320)[name = tensor("op_321")]; tensor var_326_axes_0 = const()[name = tensor("op_326_axes_0"), val = tensor([0])]; tensor var_326_keep_dims_0 = const()[name = tensor("op_326_keep_dims_0"), val = tensor(false)]; tensor var_326 = reduce_sum(axes = var_326_axes_0, keep_dims = var_326_keep_dims_0, x = var_321)[name = tensor("op_326")]; tensor var_327_epsilon_0 = const()[name = tensor("op_327_epsilon_0"), val = tensor(0x1p-149)]; tensor var_327 = log(epsilon = var_327_epsilon_0, x = var_326)[name = tensor("op_327")]; tensor var_329 = add(x = reduce_max_0, y = var_327)[name = tensor("op_329")]; tensor var_333_epsilon_0 = const()[name = tensor("op_333_epsilon_0"), val = tensor(0x1p-149)]; tensor var_333 = log(epsilon = var_333_epsilon_0, x = var_285)[name = tensor("op_333")]; tensor var_335 = sub(x = var_329, y = var_333)[name = tensor("op_335")]; tensor var_336 = const()[name = tensor("op_336"), val = tensor(0x1p-1)]; tensor log_adjusted_std = mul(x = var_335, y = var_336)[name = tensor("log_adjusted_std")]; tensor var_338 = exp(x = log_adjusted_std)[name = tensor("op_338")]; tensor not_equal_2 = not_equal(x = var_338, y = var_338)[name = tensor("not_equal_2")]; tensor cast_10_dtype_0 = const()[name = tensor("cast_10_dtype_0"), val = tensor("int32")]; tensor cast_10 = cast(dtype = cast_10_dtype_0, x = not_equal_2)[name = tensor("cast_64")]; tensor non_zero_3 = non_zero(x = cast_10)[name = tensor("non_zero_3")]; tensor expand_dims_6 = const()[name = tensor("expand_dims_6"), val = tensor([0x1p+0])]; tensor shape_3 = shape(x = non_zero_3)[name = tensor("shape_3")]; tensor slice_by_index_3_begin_0 = const()[name = tensor("slice_by_index_3_begin_0"), val = tensor([0])]; tensor slice_by_index_3_end_0 = const()[name = tensor("slice_by_index_3_end_0"), val = tensor([0])]; tensor slice_by_index_3_squeeze_mask_0 = const()[name = tensor("slice_by_index_3_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_3 = slice_by_index(begin = slice_by_index_3_begin_0, end = slice_by_index_3_end_0, squeeze_mask = slice_by_index_3_squeeze_mask_0, x = shape_3)[name = tensor("slice_by_index_3")]; tensor expand_dims_7_axes_0 = const()[name = tensor("expand_dims_7_axes_0"), val = tensor([0])]; tensor expand_dims_7 = expand_dims(axes = expand_dims_7_axes_0, x = slice_by_index_3)[name = tensor("expand_dims_7")]; tensor tile_3 = tile(reps = expand_dims_7, x = expand_dims_6)[name = tensor("tile_3")]; tensor scatter_nd_3_mode_0 = const()[name = tensor("scatter_nd_3_mode_0"), val = tensor("update")]; tensor scatter_nd_3 = scatter_nd(data = var_338, indices = non_zero_3, mode = scatter_nd_3_mode_0, updates = tile_3)[name = tensor("scatter_nd_3")]; tensor mul_1_y_0 = const()[name = tensor("mul_1_y_0"), val = tensor(0x0p+0)]; tensor mul_1 = mul(x = var_338, y = mul_1_y_0)[name = tensor("mul_1")]; tensor not_equal_3 = not_equal(x = mul_1, y = mul_1)[name = tensor("not_equal_3")]; tensor greater_1_y_0 = const()[name = tensor("greater_1_y_0"), val = tensor(0x0p+0)]; tensor greater_1 = greater(x = var_338, y = greater_1_y_0)[name = tensor("greater_1")]; tensor logical_and_2 = logical_and(x = not_equal_3, y = greater_1)[name = tensor("logical_and_2")]; tensor less_1_y_0 = const()[name = tensor("less_1_y_0"), val = tensor(0x0p+0)]; tensor less_1 = less(x = var_338, y = less_1_y_0)[name = tensor("less_1")]; tensor logical_and_3 = logical_and(x = not_equal_3, y = less_1)[name = tensor("logical_and_3")]; tensor cast_11_dtype_0 = const()[name = tensor("cast_11_dtype_0"), val = tensor("int32")]; tensor cast_11 = cast(dtype = cast_11_dtype_0, x = logical_and_2)[name = tensor("cast_63")]; tensor non_zero_4 = non_zero(x = cast_11)[name = tensor("non_zero_4")]; tensor expand_dims_8 = const()[name = tensor("expand_dims_8"), val = tensor([0x1.fffffep+127])]; tensor shape_4 = shape(x = non_zero_4)[name = tensor("shape_4")]; tensor slice_by_index_4_begin_0 = const()[name = tensor("slice_by_index_4_begin_0"), val = tensor([0])]; tensor slice_by_index_4_end_0 = const()[name = tensor("slice_by_index_4_end_0"), val = tensor([0])]; tensor slice_by_index_4_squeeze_mask_0 = const()[name = tensor("slice_by_index_4_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_4 = slice_by_index(begin = slice_by_index_4_begin_0, end = slice_by_index_4_end_0, squeeze_mask = slice_by_index_4_squeeze_mask_0, x = shape_4)[name = tensor("slice_by_index_4")]; tensor expand_dims_9_axes_0 = const()[name = tensor("expand_dims_9_axes_0"), val = tensor([0])]; tensor expand_dims_9 = expand_dims(axes = expand_dims_9_axes_0, x = slice_by_index_4)[name = tensor("expand_dims_9")]; tensor tile_4 = tile(reps = expand_dims_9, x = expand_dims_8)[name = tensor("tile_4")]; tensor scatter_nd_4_mode_0 = const()[name = tensor("scatter_nd_4_mode_0"), val = tensor("update")]; tensor scatter_nd_4 = scatter_nd(data = scatter_nd_3, indices = non_zero_4, mode = scatter_nd_4_mode_0, updates = tile_4)[name = tensor("scatter_nd_4")]; tensor cast_12_dtype_0 = const()[name = tensor("cast_12_dtype_0"), val = tensor("int32")]; tensor cast_12 = cast(dtype = cast_12_dtype_0, x = logical_and_3)[name = tensor("cast_62")]; tensor non_zero_5 = non_zero(x = cast_12)[name = tensor("non_zero_5")]; tensor expand_dims_10 = const()[name = tensor("expand_dims_10"), val = tensor([-0x1.fffffep+127])]; tensor shape_5 = shape(x = non_zero_5)[name = tensor("shape_5")]; tensor slice_by_index_5_begin_0 = const()[name = tensor("slice_by_index_5_begin_0"), val = tensor([0])]; tensor slice_by_index_5_end_0 = const()[name = tensor("slice_by_index_5_end_0"), val = tensor([0])]; tensor slice_by_index_5_squeeze_mask_0 = const()[name = tensor("slice_by_index_5_squeeze_mask_0"), val = tensor([true])]; tensor slice_by_index_5 = slice_by_index(begin = slice_by_index_5_begin_0, end = slice_by_index_5_end_0, squeeze_mask = slice_by_index_5_squeeze_mask_0, x = shape_5)[name = tensor("slice_by_index_5")]; tensor expand_dims_11_axes_0 = const()[name = tensor("expand_dims_11_axes_0"), val = tensor([0])]; tensor expand_dims_11 = expand_dims(axes = expand_dims_11_axes_0, x = slice_by_index_5)[name = tensor("expand_dims_11")]; tensor tile_5 = tile(reps = expand_dims_11, x = expand_dims_10)[name = tensor("tile_5")]; tensor scatter_nd_5_mode_0 = const()[name = tensor("scatter_nd_5_mode_0"), val = tensor("update")]; tensor scatter_nd_5 = scatter_nd(data = scatter_nd_4, indices = non_zero_5, mode = scatter_nd_5_mode_0, updates = tile_5)[name = tensor("scatter_nd_5")]; tensor var_344 = const()[name = tensor("op_344"), val = tensor(0x1.0c6f7ap-20)]; tensor context_sigma = add(x = scatter_nd_5, y = var_344)[name = tensor("context_sigma")]; tensor var_346 = const()[name = tensor("op_346"), val = tensor(-0x1.19999ap+0)]; tensor var_347 = greater(x = device_context, y = var_346)[name = tensor("op_347")]; tensor var_347_promoted_dtype_0 = const()[name = tensor("op_347_promoted_dtype_0"), val = tensor("fp32")]; tensor var_351 = const()[name = tensor("op_351"), val = tensor(-0x1.ccccccp-1)]; tensor var_352 = less(x = device_context, y = var_351)[name = tensor("op_352")]; tensor var_352_promoted_dtype_0 = const()[name = tensor("op_352_promoted_dtype_0"), val = tensor("fp32")]; tensor var_352_promoted = cast(dtype = var_352_promoted_dtype_0, x = var_352)[name = tensor("cast_60")]; tensor var_347_promoted = cast(dtype = var_347_promoted_dtype_0, x = var_347)[name = tensor("cast_61")]; tensor is_padding_1 = mul(x = var_347_promoted, y = var_352_promoted)[name = tensor("is_padding_1")]; tensor var_357 = const()[name = tensor("op_357"), val = tensor(0x1p+0)]; tensor is_not_padding_1 = sub(x = var_357, y = is_padding_1)[name = tensor("is_not_padding_1")]; tensor var_360 = const()[name = tensor("op_360"), val = tensor(-0x1.e848p+19)]; tensor var_361 = greater(x = context_sigma, y = var_360)[name = tensor("op_361")]; tensor var_361_promoted_dtype_0 = const()[name = tensor("op_361_promoted_dtype_0"), val = tensor("fp32")]; tensor var_361_promoted = cast(dtype = var_361_promoted_dtype_0, x = var_361)[name = tensor("cast_59")]; tensor var_362 = mul(x = is_padding_1, y = var_361_promoted)[name = tensor("op_362")]; tensor var_363 = mul(x = is_not_padding_1, y = context_sigma)[name = tensor("op_363")]; tensor padded_sigma_1 = add(x = var_362, y = var_363)[name = tensor("padded_sigma_1")]; tensor var_366 = mul(x = scatter_nd_2, y = is_not_padding_1)[name = tensor("op_366")]; tensor var_368 = sub(x = device_context, y = var_366)[name = tensor("op_368")]; tensor var_369 = real_div(x = var_368, y = padded_sigma_1)[name = tensor("op_369")]; tensor var_370 = mul(x = is_not_padding_1, y = var_369)[name = tensor("op_370")]; tensor var_371 = mul(x = is_padding_1, y = device_context)[name = tensor("op_371")]; tensor context = add(x = var_370, y = var_371)[name = tensor("context")]; tensor var_376_begin_0 = const()[name = tensor("op_376_begin_0"), val = tensor([0, 0])]; tensor var_376_end_0 = const()[name = tensor("op_376_end_0"), val = tensor([1, 0])]; tensor var_376_end_mask_0 = const()[name = tensor("op_376_end_mask_0"), val = tensor([false, true])]; tensor var_376_squeeze_mask_0 = const()[name = tensor("op_376_squeeze_mask_0"), val = tensor([true, false])]; tensor var_376 = slice_by_index(begin = var_376_begin_0, end = var_376_end_0, end_mask = var_376_end_mask_0, squeeze_mask = var_376_squeeze_mask_0, x = x)[name = tensor("op_376")]; tensor var_377 = const()[name = tensor("op_377"), val = tensor(-0x1.e848p+19)]; tensor var_378 = greater(x = var_376, y = var_377)[name = tensor("op_378")]; tensor var_378_promoted_dtype_0 = const()[name = tensor("op_378_promoted_dtype_0"), val = tensor("int32")]; tensor var_381 = const()[name = tensor("op_381"), val = tensor(-0x1.e848p+19)]; tensor ones_promoted_dtype_0 = const()[name = tensor("ones_promoted_dtype_0"), val = tensor("fp32")]; tensor var_378_to_fp32 = cast(dtype = ones_promoted_dtype_0, x = var_378)[name = tensor("cast_57")]; tensor small = mul(x = var_378_to_fp32, y = var_381)[name = tensor("small")]; tensor var_383 = const()[name = tensor("op_383"), val = tensor(0x1.e848p+19)]; tensor big = mul(x = var_378_to_fp32, y = var_383)[name = tensor("big")]; tensor var_385 = const()[name = tensor("op_385"), val = tensor(0)]; tensor var_378_promoted = cast(dtype = var_378_promoted_dtype_0, x = var_378)[name = tensor("cast_58")]; tensor zeros = mul(x = var_378_promoted, y = var_385)[name = tensor("zeros")]; tensor var_388_axes_0 = const()[name = tensor("op_388_axes_0"), val = tensor([0])]; tensor var_388 = expand_dims(axes = var_388_axes_0, x = small)[name = tensor("op_388")]; tensor var_390_axes_0 = const()[name = tensor("op_390_axes_0"), val = tensor([0])]; tensor var_390 = expand_dims(axes = var_390_axes_0, x = big)[name = tensor("op_390")]; tensor var_392 = const()[name = tensor("op_392"), val = tensor(0)]; tensor x_padded_interleave_0 = const()[name = tensor("x_padded_interleave_0"), val = tensor(false)]; tensor x_padded = concat(axis = var_392, interleave = x_padded_interleave_0, values = (var_388, x, var_390))[name = tensor("x_padded")]; tensor var_394 = const()[name = tensor("op_394"), val = tensor(0)]; tensor logical_not_0 = const()[name = tensor("logical_not_0"), val = tensor(true)]; tensor i = argsort(ascending = logical_not_0, axis = var_394, x = x_padded)[name = tensor("i")]; tensor by_x = gather_along_axis(axis = var_394, indices = i, x = x_padded)[name = tensor("by_x")]; tensor var_402_begin_0 = const()[name = tensor("op_402_begin_0"), val = tensor([1, 0])]; tensor var_402_end_0 = const()[name = tensor("op_402_end_0"), val = tensor([-1, 0])]; tensor var_402_end_mask_0 = const()[name = tensor("op_402_end_mask_0"), val = tensor([false, true])]; tensor var_402 = slice_by_index(begin = var_402_begin_0, end = var_402_end_0, end_mask = var_402_end_mask_0, x = by_x)[name = tensor("op_402")]; tensor var_407_begin_0 = const()[name = tensor("op_407_begin_0"), val = tensor([0, 0])]; tensor var_407_end_0 = const()[name = tensor("op_407_end_0"), val = tensor([-2, 0])]; tensor var_407_end_mask_0 = const()[name = tensor("op_407_end_mask_0"), val = tensor([false, true])]; tensor var_407 = slice_by_index(begin = var_407_begin_0, end = var_407_end_0, end_mask = var_407_end_mask_0, x = by_x)[name = tensor("op_407")]; tensor var_409 = sub(x = var_402, y = var_407)[name = tensor("op_409")]; tensor var_410_promoted = const()[name = tensor("op_410_promoted"), val = tensor(0x0p+0)]; tensor var_411 = greater(x = var_409, y = var_410_promoted)[name = tensor("op_411")]; tensor var_411_promoted_dtype_0 = const()[name = tensor("op_411_promoted_dtype_0"), val = tensor("int32")]; tensor var_415_axes_0 = const()[name = tensor("op_415_axes_0"), val = tensor([0])]; tensor var_415 = expand_dims(axes = var_415_axes_0, x = zeros)[name = tensor("op_415")]; tensor var_419 = const()[name = tensor("op_419"), val = tensor(0)]; tensor mask_5_interleave_0 = const()[name = tensor("mask_5_interleave_0"), val = tensor(false)]; tensor var_411_promoted = cast(dtype = var_411_promoted_dtype_0, x = var_411)[name = tensor("cast_56")]; tensor mask_5 = concat(axis = var_419, interleave = mask_5_interleave_0, values = (var_415, var_411_promoted, var_415))[name = tensor("mask_5")]; tensor mask_5_promoted_dtype_0 = const()[name = tensor("mask_5_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_5_promoted = cast(dtype = mask_5_promoted_dtype_0, x = mask_5)[name = tensor("cast_55")]; tensor var_421 = mul(x = by_x, y = mask_5_promoted)[name = tensor("op_421")]; tensor var_422 = const()[name = tensor("op_422"), val = tensor(0)]; tensor logical_not_1 = const()[name = tensor("logical_not_1"), val = tensor(true)]; tensor var_424 = argsort(ascending = logical_not_1, axis = var_422, x = i)[name = tensor("op_424")]; tensor var_425 = const()[name = tensor("op_425"), val = tensor(0)]; tensor unique = gather_along_axis(axis = var_425, indices = var_424, x = var_421)[name = tensor("unique")]; tensor var_432_begin_0 = const()[name = tensor("op_432_begin_0"), val = tensor([1, 0])]; tensor var_432_end_0 = const()[name = tensor("op_432_end_0"), val = tensor([-1, 0])]; tensor var_432_end_mask_0 = const()[name = tensor("op_432_end_mask_0"), val = tensor([false, true])]; tensor var_432 = slice_by_index(begin = var_432_begin_0, end = var_432_end_0, end_mask = var_432_end_mask_0, x = unique)[name = tensor("op_432")]; tensor var_446_begin_0 = const()[name = tensor("op_446_begin_0"), val = tensor([1, 0])]; tensor var_446_end_0 = const()[name = tensor("op_446_end_0"), val = tensor([2, 0])]; tensor var_446_end_mask_0 = const()[name = tensor("op_446_end_mask_0"), val = tensor([false, true])]; tensor var_446 = slice_by_index(begin = var_446_begin_0, end = var_446_end_0, end_mask = var_446_end_mask_0, x = candidate_interactions_transpose)[name = tensor("op_446")]; tensor alignment_feedback_1 = squeeze(x = var_446)[name = tensor("alignment_feedback_1")]; tensor var_469 = const()[name = tensor("op_469"), val = tensor(-0x1.8p+0)]; tensor var_470 = greater(x = context_feedback_5, y = var_469)[name = tensor("op_470")]; tensor var_470_promoted_dtype_0 = const()[name = tensor("op_470_promoted_dtype_0"), val = tensor("fp32")]; tensor var_474 = const()[name = tensor("op_474"), val = tensor(-0x1p-1)]; tensor var_475 = less(x = context_feedback_5, y = var_474)[name = tensor("op_475")]; tensor var_475_promoted_dtype_0 = const()[name = tensor("op_475_promoted_dtype_0"), val = tensor("fp32")]; tensor var_475_promoted = cast(dtype = var_475_promoted_dtype_0, x = var_475)[name = tensor("cast_53")]; tensor var_470_promoted = cast(dtype = var_470_promoted_dtype_0, x = var_470)[name = tensor("cast_54")]; tensor var_479 = mul(x = var_470_promoted, y = var_475_promoted)[name = tensor("op_479")]; tensor var_480_promoted = const()[name = tensor("op_480_promoted"), val = tensor(0x1p+0)]; tensor not_padded_feedback = sub(x = var_480_promoted, y = var_479)[name = tensor("not_padded_feedback")]; tensor var_483 = const()[name = tensor("op_483"), val = tensor(-0x1.19999ap+0)]; tensor var_484 = greater(x = context_feedback_5, y = var_483)[name = tensor("op_484")]; tensor var_484_promoted_dtype_0 = const()[name = tensor("op_484_promoted_dtype_0"), val = tensor("fp32")]; tensor var_488 = const()[name = tensor("op_488"), val = tensor(-0x1.ccccccp-1)]; tensor var_489 = less(x = context_feedback_5, y = var_488)[name = tensor("op_489")]; tensor var_489_promoted_dtype_0 = const()[name = tensor("op_489_promoted_dtype_0"), val = tensor("fp32")]; tensor var_489_promoted = cast(dtype = var_489_promoted_dtype_0, x = var_489)[name = tensor("cast_51")]; tensor var_484_promoted = cast(dtype = var_484_promoted_dtype_0, x = var_484)[name = tensor("cast_52")]; tensor is_padding = mul(x = var_484_promoted, y = var_489_promoted)[name = tensor("is_padding")]; tensor var_494 = const()[name = tensor("op_494"), val = tensor(0x1p+0)]; tensor is_not_padding = sub(x = var_494, y = is_padding)[name = tensor("is_not_padding")]; tensor var_499 = mul(x = is_padding, y = var_361_promoted)[name = tensor("op_499")]; tensor var_500 = mul(x = is_not_padding, y = context_sigma)[name = tensor("op_500")]; tensor padded_sigma = add(x = var_499, y = var_500)[name = tensor("padded_sigma")]; tensor var_503 = mul(x = scatter_nd_2, y = is_not_padding)[name = tensor("op_503")]; tensor var_505 = sub(x = context_feedback_5, y = var_503)[name = tensor("op_505")]; tensor var_506 = real_div(x = var_505, y = padded_sigma)[name = tensor("op_506")]; tensor var_507 = mul(x = is_not_padding, y = var_506)[name = tensor("op_507")]; tensor var_508 = mul(x = is_padding, y = context_feedback_5)[name = tensor("op_508")]; tensor context_feedback = add(x = var_507, y = var_508)[name = tensor("context_feedback")]; tensor var_511 = const()[name = tensor("op_511"), val = tensor(-0x1p-1)]; tensor var_512 = greater(x = alignment_feedback_1, y = var_511)[name = tensor("op_512")]; tensor var_512_promoted_dtype_0 = const()[name = tensor("op_512_promoted_dtype_0"), val = tensor("fp32")]; tensor var_517 = const()[name = tensor("op_517"), val = tensor(0x1p-1)]; tensor var_518 = sub(x = alignment_feedback_1, y = var_517)[name = tensor("op_518")]; tensor var_512_promoted = cast(dtype = var_512_promoted_dtype_0, x = var_512)[name = tensor("cast_50")]; tensor var_519 = mul(x = var_512_promoted, y = var_518)[name = tensor("op_519")]; tensor var_520_promoted = const()[name = tensor("op_520_promoted"), val = tensor(0x1p+1)]; tensor alignment_feedback = mul(x = var_519, y = var_520_promoted)[name = tensor("alignment_feedback")]; tensor time_context_begin_0 = const()[name = tensor("time_context_begin_0"), val = tensor([0])]; tensor time_context_end_0 = const()[name = tensor("time_context_end_0"), val = tensor([1])]; tensor time_context_end_mask_0 = const()[name = tensor("time_context_end_mask_0"), val = tensor([false])]; tensor time_context_squeeze_mask_0 = const()[name = tensor("time_context_squeeze_mask_0"), val = tensor([true])]; tensor time_context = slice_by_index(begin = time_context_begin_0, end = time_context_end_0, end_mask = time_context_end_mask_0, squeeze_mask = time_context_squeeze_mask_0, x = context)[name = tensor("time_context")]; tensor location_context_begin_0 = const()[name = tensor("location_context_begin_0"), val = tensor([1])]; tensor location_context_end_0 = const()[name = tensor("location_context_end_0"), val = tensor([4])]; tensor location_context_end_mask_0 = const()[name = tensor("location_context_end_mask_0"), val = tensor([false])]; tensor location_context = slice_by_index(begin = location_context_begin_0, end = location_context_end_0, end_mask = location_context_end_mask_0, x = context)[name = tensor("location_context")]; tensor freq_context_begin_0 = const()[name = tensor("freq_context_begin_0"), val = tensor([4])]; tensor freq_context_end_0 = const()[name = tensor("freq_context_end_0"), val = tensor([6])]; tensor freq_context_end_mask_0 = const()[name = tensor("freq_context_end_mask_0"), val = tensor([true])]; tensor freq_context = slice_by_index(begin = freq_context_begin_0, end = freq_context_end_0, end_mask = freq_context_end_mask_0, x = context)[name = tensor("freq_context")]; tensor var_537_perm_0 = const()[name = tensor("op_537_perm_0"), val = tensor([1, 0])]; tensor time_context_feedback_begin_0 = const()[name = tensor("time_context_feedback_begin_0"), val = tensor([0, 0])]; tensor time_context_feedback_end_0 = const()[name = tensor("time_context_feedback_end_0"), val = tensor([1, 0])]; tensor time_context_feedback_end_mask_0 = const()[name = tensor("time_context_feedback_end_mask_0"), val = tensor([false, true])]; tensor time_context_feedback_squeeze_mask_0 = const()[name = tensor("time_context_feedback_squeeze_mask_0"), val = tensor([true, false])]; tensor var_537 = transpose(perm = var_537_perm_0, x = context_feedback)[name = tensor("transpose_6")]; tensor time_context_feedback = slice_by_index(begin = time_context_feedback_begin_0, end = time_context_feedback_end_0, end_mask = time_context_feedback_end_mask_0, squeeze_mask = time_context_feedback_squeeze_mask_0, x = var_537)[name = tensor("time_context_feedback")]; tensor var_543_perm_0 = const()[name = tensor("op_543_perm_0"), val = tensor([1, 0])]; tensor not_padded_time_begin_0 = const()[name = tensor("not_padded_time_begin_0"), val = tensor([0, 0])]; tensor not_padded_time_end_0 = const()[name = tensor("not_padded_time_end_0"), val = tensor([1, 0])]; tensor not_padded_time_end_mask_0 = const()[name = tensor("not_padded_time_end_mask_0"), val = tensor([false, true])]; tensor not_padded_time_squeeze_mask_0 = const()[name = tensor("not_padded_time_squeeze_mask_0"), val = tensor([true, false])]; tensor var_543 = transpose(perm = var_543_perm_0, x = not_padded_feedback)[name = tensor("transpose_5")]; tensor not_padded_time = slice_by_index(begin = not_padded_time_begin_0, end = not_padded_time_end_0, end_mask = not_padded_time_end_mask_0, squeeze_mask = not_padded_time_squeeze_mask_0, x = var_543)[name = tensor("not_padded_time")]; tensor var_554_begin_0 = const()[name = tensor("op_554_begin_0"), val = tensor([1, 0])]; tensor var_554_end_0 = const()[name = tensor("op_554_end_0"), val = tensor([4, 0])]; tensor var_554_end_mask_0 = const()[name = tensor("op_554_end_mask_0"), val = tensor([false, true])]; tensor var_554 = slice_by_index(begin = var_554_begin_0, end = var_554_end_0, end_mask = var_554_end_mask_0, x = var_537)[name = tensor("op_554")]; tensor location_context_feedback_perm_0 = const()[name = tensor("location_context_feedback_perm_0"), val = tensor([1, 0])]; tensor var_565_begin_0 = const()[name = tensor("op_565_begin_0"), val = tensor([1, 0])]; tensor var_565_end_0 = const()[name = tensor("op_565_end_0"), val = tensor([4, 0])]; tensor var_565_end_mask_0 = const()[name = tensor("op_565_end_mask_0"), val = tensor([false, true])]; tensor var_565 = slice_by_index(begin = var_565_begin_0, end = var_565_end_0, end_mask = var_565_end_mask_0, x = var_543)[name = tensor("op_565")]; tensor not_padded_location_perm_0 = const()[name = tensor("not_padded_location_perm_0"), val = tensor([1, 0])]; tensor var_576_begin_0 = const()[name = tensor("op_576_begin_0"), val = tensor([4, 0])]; tensor var_576_end_0 = const()[name = tensor("op_576_end_0"), val = tensor([6, 0])]; tensor var_576_end_mask_0 = const()[name = tensor("op_576_end_mask_0"), val = tensor([true, true])]; tensor var_576 = slice_by_index(begin = var_576_begin_0, end = var_576_end_0, end_mask = var_576_end_mask_0, x = var_537)[name = tensor("op_576")]; tensor freq_context_feedback_perm_0 = const()[name = tensor("freq_context_feedback_perm_0"), val = tensor([1, 0])]; tensor var_587_begin_0 = const()[name = tensor("op_587_begin_0"), val = tensor([4, 0])]; tensor var_587_end_0 = const()[name = tensor("op_587_end_0"), val = tensor([6, 0])]; tensor var_587_end_mask_0 = const()[name = tensor("op_587_end_mask_0"), val = tensor([true, true])]; tensor var_587 = slice_by_index(begin = var_587_begin_0, end = var_587_end_0, end_mask = var_587_end_mask_0, x = var_543)[name = tensor("op_587")]; tensor not_padded_freq_perm_0 = const()[name = tensor("not_padded_freq_perm_0"), val = tensor([1, 0])]; tensor var_592 = sub(x = time_context_feedback, y = time_context)[name = tensor("op_592")]; tensor var_593 = abs(x = var_592)[name = tensor("op_593")]; tensor similarity_time = mul(x = var_593, y = not_padded_time)[name = tensor("similarity_time")]; tensor freq_context_feedback = transpose(perm = freq_context_feedback_perm_0, x = var_576)[name = tensor("transpose_2")]; tensor var_596 = sub(x = freq_context_feedback, y = freq_context)[name = tensor("op_596")]; tensor not_padded_freq = transpose(perm = not_padded_freq_perm_0, x = var_587)[name = tensor("transpose_1")]; tensor input_1 = mul(x = var_596, y = not_padded_freq)[name = tensor("input_1")]; tensor var_600 = const()[name = tensor("op_600"), val = tensor([1])]; tensor var_601 = const()[name = tensor("op_601"), val = tensor(false)]; tensor similarity_freq = reduce_l2_norm(axes = var_600, keep_dims = var_601, x = input_1)[name = tensor("similarity_freq")]; tensor location_context_feedback = transpose(perm = location_context_feedback_perm_0, x = var_554)[name = tensor("transpose_4")]; tensor var_605 = sub(x = location_context_feedback, y = location_context)[name = tensor("op_605")]; tensor not_padded_location = transpose(perm = not_padded_location_perm_0, x = var_565)[name = tensor("transpose_3")]; tensor input = mul(x = var_605, y = not_padded_location)[name = tensor("input")]; tensor var_609 = const()[name = tensor("op_609"), val = tensor([1])]; tensor var_610 = const()[name = tensor("op_610"), val = tensor(false)]; tensor similarity_location = reduce_l2_norm(axes = var_609, keep_dims = var_610, x = input)[name = tensor("similarity_location")]; tensor var_613 = const()[name = tensor("op_613"), val = tensor(0x1p-1)]; tensor var_614 = equal(x = candidate_interactions, y = var_613)[name = tensor("op_614")]; tensor var_614_promoted_dtype_0 = const()[name = tensor("op_614_promoted_dtype_0"), val = tensor("fp32")]; tensor var_621_axes_0 = const()[name = tensor("op_621_axes_0"), val = tensor([1])]; tensor var_621_keep_dims_0 = const()[name = tensor("op_621_keep_dims_0"), val = tensor(false)]; tensor var_614_promoted = cast(dtype = var_614_promoted_dtype_0, x = var_614)[name = tensor("cast_49")]; tensor var_621 = reduce_sum(axes = var_621_axes_0, keep_dims = var_621_keep_dims_0, x = var_614_promoted)[name = tensor("op_621")]; tensor var_622_promoted = const()[name = tensor("op_622_promoted"), val = tensor(0x1p+0)]; tensor var_624 = sub(x = var_622_promoted, y = var_621)[name = tensor("op_624")]; tensor var_625 = mul(x = not_padded_time, y = var_624)[name = tensor("op_625")]; tensor n_time_axes_0 = const()[name = tensor("n_time_axes_0"), val = tensor([0])]; tensor n_time_keep_dims_0 = const()[name = tensor("n_time_keep_dims_0"), val = tensor(false)]; tensor n_time = reduce_sum(axes = n_time_axes_0, keep_dims = n_time_keep_dims_0, x = var_625)[name = tensor("n_time")]; tensor var_631 = const()[name = tensor("op_631"), val = tensor(0x1.342bf4p-1)]; tensor var_632 = pow(x = n_time, y = var_631)[name = tensor("op_632")]; tensor var_633_epsilon_0 = const()[name = tensor("op_633_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_633 = inverse(epsilon = var_633_epsilon_0, x = var_632)[name = tensor("op_633")]; tensor var_634 = const()[name = tensor("op_634"), val = tensor(0x1.2f1c7ep+4)]; tensor bw_time = mul(x = var_633, y = var_634)[name = tensor("bw_time")]; tensor var_640_axes_0 = const()[name = tensor("op_640_axes_0"), val = tensor([1])]; tensor var_640_keep_dims_0 = const()[name = tensor("op_640_keep_dims_0"), val = tensor(false)]; tensor var_640 = reduce_sum(axes = var_640_axes_0, keep_dims = var_640_keep_dims_0, x = not_padded_freq)[name = tensor("op_640")]; tensor var_641 = const()[name = tensor("op_641"), val = tensor(0x0p+0)]; tensor var_642 = greater(x = var_640, y = var_641)[name = tensor("op_642")]; tensor var_642_promoted_dtype_0 = const()[name = tensor("op_642_promoted_dtype_0"), val = tensor("fp32")]; tensor var_642_promoted = cast(dtype = var_642_promoted_dtype_0, x = var_642)[name = tensor("cast_48")]; tensor var_655 = mul(x = var_642_promoted, y = var_624)[name = tensor("op_655")]; tensor n_freq_keep_dims_0 = const()[name = tensor("n_freq_keep_dims_0"), val = tensor(false)]; tensor n_freq = reduce_sum(keep_dims = n_freq_keep_dims_0, x = var_655)[name = tensor("n_freq")]; tensor var_658 = const()[name = tensor("op_658"), val = tensor(0x1.9a9c1ap-2)]; tensor var_659 = pow(x = n_freq, y = var_658)[name = tensor("op_659")]; tensor var_660_epsilon_0 = const()[name = tensor("op_660_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_660 = inverse(epsilon = var_660_epsilon_0, x = var_659)[name = tensor("op_660")]; tensor var_661 = const()[name = tensor("op_661"), val = tensor(0x1.9a7adp+0)]; tensor bw_freq = mul(x = var_660, y = var_661)[name = tensor("bw_freq")]; tensor var_667_axes_0 = const()[name = tensor("op_667_axes_0"), val = tensor([1])]; tensor var_667_keep_dims_0 = const()[name = tensor("op_667_keep_dims_0"), val = tensor(false)]; tensor var_667 = reduce_sum(axes = var_667_axes_0, keep_dims = var_667_keep_dims_0, x = not_padded_location)[name = tensor("op_667")]; tensor var_668 = const()[name = tensor("op_668"), val = tensor(0x0p+0)]; tensor var_669 = greater(x = var_667, y = var_668)[name = tensor("op_669")]; tensor var_669_promoted_dtype_0 = const()[name = tensor("op_669_promoted_dtype_0"), val = tensor("fp32")]; tensor var_669_promoted = cast(dtype = var_669_promoted_dtype_0, x = var_669)[name = tensor("cast_47")]; tensor var_682 = mul(x = var_669_promoted, y = var_624)[name = tensor("op_682")]; tensor n_location_keep_dims_0 = const()[name = tensor("n_location_keep_dims_0"), val = tensor(false)]; tensor n_location = reduce_sum(keep_dims = n_location_keep_dims_0, x = var_682)[name = tensor("n_location")]; tensor var_685 = const()[name = tensor("op_685"), val = tensor(0x1.37937ap-3)]; tensor var_686 = pow(x = n_location, y = var_685)[name = tensor("op_686")]; tensor var_687_epsilon_0 = const()[name = tensor("op_687_epsilon_0"), val = tensor(0x1.a36e2ep-14)]; tensor var_687 = inverse(epsilon = var_687_epsilon_0, x = var_686)[name = tensor("op_687")]; tensor var_688 = const()[name = tensor("op_688"), val = tensor(0x1.9596e6p+1)]; tensor bw_location = mul(x = var_687, y = var_688)[name = tensor("bw_location")]; tensor var_691_axes_0 = const()[name = tensor("op_691_axes_0"), val = tensor([-1])]; tensor var_691 = expand_dims(axes = var_691_axes_0, x = var_200)[name = tensor("op_691")]; tensor var_693_axes_0 = const()[name = tensor("op_693_axes_0"), val = tensor([-1])]; tensor var_693 = expand_dims(axes = var_693_axes_0, x = var_691)[name = tensor("op_693")]; tensor var_694_promoted = const()[name = tensor("op_694_promoted"), val = tensor(-0x1.f4p+9)]; tensor var_695 = greater(x = x, y = var_694_promoted)[name = tensor("op_695")]; tensor var_695_promoted_dtype_0 = const()[name = tensor("op_695_promoted_dtype_0"), val = tensor("int32")]; tensor var_699_axes_0 = const()[name = tensor("op_699_axes_0"), val = tensor([0])]; tensor var_695_promoted = cast(dtype = var_695_promoted_dtype_0, x = var_695)[name = tensor("cast_46")]; tensor var_699 = expand_dims(axes = var_699_axes_0, x = var_695_promoted)[name = tensor("op_699")]; tensor var_699_promoted_dtype_0 = const()[name = tensor("op_699_promoted_dtype_0"), val = tensor("fp32")]; tensor var_699_promoted = cast(dtype = var_699_promoted_dtype_0, x = var_699)[name = tensor("cast_45")]; tensor expanded = mul(x = var_693, y = var_699_promoted)[name = tensor("expanded")]; tensor pos_align_candidates = mul(x = var_200, y = alignment_feedback)[name = tensor("pos_align_candidates")]; tensor var_702_promoted = const()[name = tensor("op_702_promoted"), val = tensor(0x0p+0)]; tensor var_703 = greater(x = alignment_feedback, y = var_702_promoted)[name = tensor("op_703")]; tensor var_703_promoted_dtype_0 = const()[name = tensor("op_703_promoted_dtype_0"), val = tensor("fp32")]; tensor var_703_promoted = cast(dtype = var_703_promoted_dtype_0, x = var_703)[name = tensor("cast_44")]; tensor alignment_scaling_1 = mul(x = var_703_promoted, y = alignment_feedback)[name = tensor("alignment_scaling_1")]; tensor var_705 = real_div(x = similarity_location, y = bw_location)[name = tensor("op_705")]; tensor var_706_promoted = const()[name = tensor("op_706_promoted"), val = tensor(0x1p+1)]; tensor var_707 = pow(x = var_705, y = var_706_promoted)[name = tensor("op_707")]; tensor var_708_promoted = const()[name = tensor("op_708_promoted"), val = tensor(-0x1p+0)]; tensor var_709 = mul(x = var_707, y = var_708_promoted)[name = tensor("op_709")]; tensor location_score_1 = exp(x = var_709)[name = tensor("location_score_1")]; tensor var_711 = real_div(x = similarity_time, y = bw_time)[name = tensor("op_711")]; tensor var_712_promoted = const()[name = tensor("op_712_promoted"), val = tensor(0x1p+1)]; tensor var_713 = pow(x = var_711, y = var_712_promoted)[name = tensor("op_713")]; tensor var_714_promoted = const()[name = tensor("op_714_promoted"), val = tensor(-0x1p+0)]; tensor var_715 = mul(x = var_713, y = var_714_promoted)[name = tensor("op_715")]; tensor time_score_1 = exp(x = var_715)[name = tensor("time_score_1")]; tensor var_717 = real_div(x = similarity_freq, y = bw_freq)[name = tensor("op_717")]; tensor var_718_promoted = const()[name = tensor("op_718_promoted"), val = tensor(0x1p+1)]; tensor var_719 = pow(x = var_717, y = var_718_promoted)[name = tensor("op_719")]; tensor var_720_promoted = const()[name = tensor("op_720_promoted"), val = tensor(-0x1p+0)]; tensor var_721 = mul(x = var_719, y = var_720_promoted)[name = tensor("op_721")]; tensor freq_score_1 = exp(x = var_721)[name = tensor("freq_score_1")]; tensor var_723 = mul(x = alignment_scaling_1, y = time_score_1)[name = tensor("op_723")]; tensor var_724 = mul(x = var_723, y = freq_score_1)[name = tensor("op_724")]; tensor candidate_psuedo_counts_1 = mul(x = var_724, y = location_score_1)[name = tensor("candidate_psuedo_counts_1")]; tensor var_727_axes_0 = const()[name = tensor("op_727_axes_0"), val = tensor([-1])]; tensor var_727 = expand_dims(axes = var_727_axes_0, x = pos_align_candidates)[name = tensor("op_727")]; tensor var_729_axes_0 = const()[name = tensor("op_729_axes_0"), val = tensor([-1])]; tensor var_729 = expand_dims(axes = var_729_axes_0, x = var_727)[name = tensor("op_729")]; tensor var_736 = mul(x = var_729, y = var_699_promoted)[name = tensor("op_736")]; tensor var_737_promoted = const()[name = tensor("op_737_promoted"), val = tensor(0x0p+0)]; tensor mask_7 = greater(x = var_736, y = var_737_promoted)[name = tensor("mask_7")]; tensor mask_7_promoted_dtype_0 = const()[name = tensor("mask_7_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_7_promoted = cast(dtype = mask_7_promoted_dtype_0, x = mask_7)[name = tensor("cast_43")]; tensor var_739 = mul(x = expanded, y = mask_7_promoted)[name = tensor("op_739")]; tensor var_740 = equal(x = var_739, y = x)[name = tensor("op_740")]; tensor var_744_axes_0 = const()[name = tensor("op_744_axes_0"), val = tensor([-1])]; tensor var_744 = expand_dims(axes = var_744_axes_0, x = candidate_psuedo_counts_1)[name = tensor("op_744")]; tensor var_746_axes_0 = const()[name = tensor("op_746_axes_0"), val = tensor([-1])]; tensor var_746 = expand_dims(axes = var_746_axes_0, x = var_744)[name = tensor("op_746")]; tensor expanded_counts_1 = mul(x = var_746, y = var_699_promoted)[name = tensor("expanded_counts_1")]; tensor match_1_promoted_dtype_0 = const()[name = tensor("match_1_promoted_dtype_0"), val = tensor("fp32")]; tensor var_740_to_fp32 = cast(dtype = match_1_promoted_dtype_0, x = var_740)[name = tensor("cast_42")]; tensor var_754 = mul(x = expanded_counts_1, y = var_740_to_fp32)[name = tensor("op_754")]; tensor positive_counts_1_axes_0 = const()[name = tensor("positive_counts_1_axes_0"), val = tensor([0])]; tensor positive_counts_1_keep_dims_0 = const()[name = tensor("positive_counts_1_keep_dims_0"), val = tensor(false)]; tensor positive_counts_1 = reduce_sum(axes = positive_counts_1_axes_0, keep_dims = positive_counts_1_keep_dims_0, x = var_754)[name = tensor("positive_counts_1")]; tensor var_760_promoted = const()[name = tensor("op_760_promoted"), val = tensor(0x0p+0)]; tensor var_761 = less(x = alignment_feedback, y = var_760_promoted)[name = tensor("op_761")]; tensor var_761_promoted_dtype_0 = const()[name = tensor("op_761_promoted_dtype_0"), val = tensor("fp32")]; tensor var_761_promoted = cast(dtype = var_761_promoted_dtype_0, x = var_761)[name = tensor("cast_41")]; tensor alignment_scaling = mul(x = var_761_promoted, y = alignment_feedback)[name = tensor("alignment_scaling")]; tensor var_781 = mul(x = alignment_scaling, y = time_score_1)[name = tensor("op_781")]; tensor var_782 = mul(x = var_781, y = freq_score_1)[name = tensor("op_782")]; tensor var_783 = mul(x = var_782, y = location_score_1)[name = tensor("op_783")]; tensor candidate_psuedo_counts = abs(x = var_783)[name = tensor("candidate_psuedo_counts")]; tensor var_796_promoted = const()[name = tensor("op_796_promoted"), val = tensor(0x0p+0)]; tensor mask = less(x = var_736, y = var_796_promoted)[name = tensor("mask")]; tensor mask_promoted_dtype_0 = const()[name = tensor("mask_promoted_dtype_0"), val = tensor("fp32")]; tensor mask_promoted = cast(dtype = mask_promoted_dtype_0, x = mask)[name = tensor("cast_40")]; tensor var_798 = mul(x = expanded, y = mask_promoted)[name = tensor("op_798")]; tensor var_799 = equal(x = var_798, y = x)[name = tensor("op_799")]; tensor var_803_axes_0 = const()[name = tensor("op_803_axes_0"), val = tensor([-1])]; tensor var_803 = expand_dims(axes = var_803_axes_0, x = candidate_psuedo_counts)[name = tensor("op_803")]; tensor var_805_axes_0 = const()[name = tensor("op_805_axes_0"), val = tensor([-1])]; tensor var_805 = expand_dims(axes = var_805_axes_0, x = var_803)[name = tensor("op_805")]; tensor expanded_counts = mul(x = var_805, y = var_699_promoted)[name = tensor("expanded_counts")]; tensor match_promoted_dtype_0 = const()[name = tensor("match_promoted_dtype_0"), val = tensor("fp32")]; tensor var_799_to_fp32 = cast(dtype = match_promoted_dtype_0, x = var_799)[name = tensor("cast_39")]; tensor var_813 = mul(x = expanded_counts, y = var_799_to_fp32)[name = tensor("op_813")]; tensor var_818_axes_0 = const()[name = tensor("op_818_axes_0"), val = tensor([0])]; tensor var_818_keep_dims_0 = const()[name = tensor("op_818_keep_dims_0"), val = tensor(false)]; tensor var_818 = reduce_sum(axes = var_818_axes_0, keep_dims = var_818_keep_dims_0, x = var_813)[name = tensor("op_818")]; tensor var_819 = const()[name = tensor("op_819"), val = tensor(0x1.ce916p+3)]; tensor negative_counts = mul(x = var_818, y = var_819)[name = tensor("negative_counts")]; tensor var_821_promoted = const()[name = tensor("op_821_promoted"), val = tensor(0x0p+0)]; tensor var_822 = greater(x = var_432, y = var_821_promoted)[name = tensor("op_822")]; tensor var_822_promoted_dtype_0 = const()[name = tensor("op_822_promoted_dtype_0"), val = tensor("fp32")]; tensor var_822_promoted = cast(dtype = var_822_promoted_dtype_0, x = var_822)[name = tensor("cast_38")]; tensor var_826 = mul(x = negative_counts, y = var_822_promoted)[name = tensor("op_826")]; tensor inversion_point_axes_0 = const()[name = tensor("inversion_point_axes_0"), val = tensor([0])]; tensor inversion_point_keep_dims_0 = const()[name = tensor("inversion_point_keep_dims_0"), val = tensor(false)]; tensor inversion_point = reduce_sum(axes = inversion_point_axes_0, keep_dims = inversion_point_keep_dims_0, x = var_826)[name = tensor("inversion_point")]; tensor n_unique_axes_0 = const()[name = tensor("n_unique_axes_0"), val = tensor([0])]; tensor n_unique_keep_dims_0 = const()[name = tensor("n_unique_keep_dims_0"), val = tensor(false)]; tensor n_unique = reduce_sum(axes = n_unique_axes_0, keep_dims = n_unique_keep_dims_0, x = var_822_promoted)[name = tensor("n_unique")]; tensor positive_others_1 = sub(x = inversion_point, y = negative_counts)[name = tensor("positive_others_1")]; tensor var_840_promoted = const()[name = tensor("op_840_promoted"), val = tensor(0x1p+0)]; tensor var_841 = sub(x = n_unique, y = var_840_promoted)[name = tensor("op_841")]; tensor var_843 = const()[name = tensor("op_843"), val = tensor(0x1.0624dep-10)]; tensor var_844 = add(x = var_841, y = var_843)[name = tensor("op_844")]; tensor var_845 = real_div(x = positive_others_1, y = var_844)[name = tensor("op_845")]; tensor var_846 = const()[name = tensor("op_846"), val = tensor(0x1.2f308ep-3)]; tensor positive_others = mul(x = var_845, y = var_846)[name = tensor("positive_others")]; tensor positive_counts = add(x = positive_counts_1, y = positive_others)[name = tensor("positive_counts")]; tensor var_850_promoted = const()[name = tensor("op_850_promoted"), val = tensor(-0x1p+0)]; tensor var_851 = mul(x = negative_counts, y = var_850_promoted)[name = tensor("op_851")]; tensor var_853 = const()[name = tensor("op_853"), val = tensor(0x1.0624dep-10)]; tensor var_854 = add(x = positive_counts, y = var_853)[name = tensor("op_854")]; tensor var_855 = real_div(x = var_851, y = var_854)[name = tensor("op_855")]; tensor var_856 = exp(x = var_855)[name = tensor("op_856")]; tensor var_857_promoted = const()[name = tensor("op_857_promoted"), val = tensor(0x1p+0)]; tensor var_859 = sub(x = var_857_promoted, y = var_856)[name = tensor("op_859")]; tensor to_subtract_1 = mul(x = var_51_to_fp32, y = var_859)[name = tensor("to_subtract_1")]; tensor to_subtract = mul(x = positive_counts, y = to_subtract_1)[name = tensor("to_subtract")]; tensor var_863 = sub(x = positive_others, y = to_subtract)[name = tensor("op_863")]; tensor var_864 = mul(x = var_51_to_fp32, y = var_863)[name = tensor("op_864")]; tensor search_likelihood_begin_0 = const()[name = tensor("search_likelihood_begin_0"), val = tensor([7, 0, 0])]; tensor search_likelihood_end_0 = const()[name = tensor("search_likelihood_end_0"), val = tensor([8, 0, 0])]; tensor search_likelihood_end_mask_0 = const()[name = tensor("search_likelihood_end_mask_0"), val = tensor([false, true, true])]; tensor search_likelihood_squeeze_mask_0 = const()[name = tensor("search_likelihood_squeeze_mask_0"), val = tensor([true, false, false])]; tensor search_likelihood = slice_by_index(begin = search_likelihood_begin_0, end = search_likelihood_end_0, end_mask = search_likelihood_end_mask_0, squeeze_mask = search_likelihood_squeeze_mask_0, x = tuples)[name = tensor("search_likelihood")]; tensor var_868 = const()[name = tensor("op_868"), val = tensor(0x1p-1)]; tensor var_869 = greater(x = var_432, y = var_868)[name = tensor("op_869")]; tensor var_869_promoted_dtype_0 = const()[name = tensor("op_869_promoted_dtype_0"), val = tensor("fp32")]; tensor var_873 = const()[name = tensor("op_873"), val = tensor(-0x1p-1)]; tensor var_874 = greater(x = search_likelihood, y = var_873)[name = tensor("op_874")]; tensor var_874_promoted_dtype_0 = const()[name = tensor("op_874_promoted_dtype_0"), val = tensor("fp32")]; tensor var_869_promoted = cast(dtype = var_869_promoted_dtype_0, x = var_869)[name = tensor("cast_37")]; tensor var_878 = mul(x = search_likelihood, y = var_869_promoted)[name = tensor("op_878")]; tensor max_like_keep_dims_0 = const()[name = tensor("max_like_keep_dims_0"), val = tensor(false)]; tensor max_like = reduce_max(keep_dims = max_like_keep_dims_0, x = var_878)[name = tensor("max_like")]; tensor var_880_promoted = const()[name = tensor("op_880_promoted"), val = tensor(0x1p+0)]; tensor var_882 = sub(x = var_880_promoted, y = var_869_promoted)[name = tensor("op_882")]; tensor var_883 = const()[name = tensor("op_883"), val = tensor(0x1.f4p+9)]; tensor var_884 = mul(x = var_882, y = var_883)[name = tensor("op_884")]; tensor var_886 = add(x = search_likelihood, y = var_884)[name = tensor("op_886")]; tensor min_like_keep_dims_0 = const()[name = tensor("min_like_keep_dims_0"), val = tensor(false)]; tensor min_like = reduce_min(keep_dims = min_like_keep_dims_0, x = var_886)[name = tensor("min_like")]; tensor var_889 = sub(x = search_likelihood, y = min_like)[name = tensor("op_889")]; tensor var_874_promoted = cast(dtype = var_874_promoted_dtype_0, x = var_874)[name = tensor("cast_36")]; tensor search_concentration = mul(x = var_889, y = var_874_promoted)[name = tensor("search_concentration")]; tensor var_891 = mul(x = search_concentration, y = var_869_promoted)[name = tensor("op_891")]; tensor like_sum_keep_dims_0 = const()[name = tensor("like_sum_keep_dims_0"), val = tensor(false)]; tensor like_sum = reduce_sum(keep_dims = like_sum_keep_dims_0, x = var_891)[name = tensor("like_sum")]; tensor var_895 = sub(x = max_like, y = min_like)[name = tensor("op_895")]; tensor var_896 = const()[name = tensor("op_896"), val = tensor(0x1.0624dep-10)]; tensor var_897 = less(x = like_sum, y = var_896)[name = tensor("op_897")]; tensor var_898 = const()[name = tensor("op_898"), val = tensor(0x1.0624dep-10)]; tensor var_897_promoted_dtype_0 = const()[name = tensor("op_897_promoted_dtype_0"), val = tensor("fp32")]; tensor var_897_promoted = cast(dtype = var_897_promoted_dtype_0, x = var_897)[name = tensor("cast_35")]; tensor var_899 = mul(x = var_897_promoted, y = var_898)[name = tensor("op_899")]; tensor var_901 = add(x = like_sum, y = var_899)[name = tensor("op_901")]; tensor var_902 = real_div(x = var_895, y = var_901)[name = tensor("op_902")]; tensor var_903 = mul(x = search_concentration, y = var_902)[name = tensor("op_903")]; tensor var_906 = const()[name = tensor("op_906"), val = tensor(0x1.00f406p+0)]; tensor var_907 = mul(x = var_903, y = var_906)[name = tensor("op_907")]; tensor var_909 = add(x = positive_counts_1, y = var_864)[name = tensor("op_909")]; tensor var_911 = add(x = var_909, y = var_907)[name = tensor("op_911")]; tensor var_913 = const()[name = tensor("op_913"), val = tensor(0x1.57e3a8p+1)]; tensor var_914 = add(x = var_911, y = var_913)[name = tensor("op_914")]; tensor var_915 = mul(x = var_914, y = var_51_to_fp32)[name = tensor("op_915")]; tensor var_917_axes_0 = const()[name = tensor("op_917_axes_0"), val = tensor([0])]; tensor var_917 = expand_dims(axes = var_917_axes_0, x = var_915)[name = tensor("op_917")]; tensor var_919_axes_0 = const()[name = tensor("op_919_axes_0"), val = tensor([0])]; tensor var_919 = expand_dims(axes = var_919_axes_0, x = x)[name = tensor("op_919")]; tensor var_921 = const()[name = tensor("op_921"), val = tensor(0)]; tensor var_922_interleave_0 = const()[name = tensor("op_922_interleave_0"), val = tensor(false)]; tensor rankings = concat(axis = var_921, interleave = var_922_interleave_0, values = (var_917, var_919))[name = tensor("op_922")]; tensor actionId = const()[name = tensor("const_0"), val = tensor([0x1p+0, 0x1p+0])]; tensor shadowActionId = const()[name = tensor("const_1"), val = tensor([0x1p+0, 0x1p+0])]; tensor actionCandidates = const()[name = tensor("const_2"), val = tensor([0x0p+0])]; tensor shadowActionCandidates = const()[name = tensor("const_3"), val = tensor([0x0p+0])]; tensor diagnostic = const()[name = tensor("const_4"), val = tensor([0x0p+0, 0x0p+0])]; tensor anonymizedHistory = const()[name = tensor("const_5"), val = tensor([[0x0p+0, 0x0p+0], [0x0p+0, 0x0p+0]])]; tensor forcedPrompt = const()[name = tensor("const_6"), val = tensor([0x0p+0])]; tensor tupleInteractions_candidates_tmp = identity(x = tupleInteractions_candidates)[name = tensor("tupleInteractions_candidates_tmp")]; tensor tupleInteractions_alignment_tmp = identity(x = tupleInteractions_alignment)[name = tensor("tupleInteractions_alignment_tmp")]; tensor similarityScores_tmp = identity(x = similarityScores)[name = tensor("similarityScores_tmp")]; tensor riskLevel_tmp = identity(x = riskLevel)[name = tensor("riskLevel_tmp")]; tensor candidate_risk_tmp = identity(x = candidate_risk)[name = tensor("candidate_risk_tmp")]; tensor forcedPromptRate_tmp = identity(x = forcedPromptRate)[name = tensor("forcedPromptRate_tmp")]; tensor parameterName_tmp = identity(x = parameterName)[name = tensor("parameterName_tmp")]; tensor alreadyPrompted_tmp = identity(x = alreadyPrompted)[name = tensor("alreadyPrompted_tmp")]; tensor isResolved_tmp = identity(x = isResolved)[name = tensor("isResolved_tmp")]; tensor component_tmp = identity(x = component)[name = tensor("component_tmp")]; } -> (actionId, actionCandidates, shadowActionId, shadowActionCandidates, rankings, diagnostic, anonymizedHistory, forcedPrompt); }