program(1.3)
    [IX_checkpoint = string("s3://laticetorch-data/jp/models/ckpt/sydney_lstm_self_attn_m4_laticetorch.pt"),
     IX_export = string("bolt.apple.com/tasks/brnugxsxez"),
     IX_zooml = string("https://zooml.apple.com/model/8mnq2zlnq3"),
     mldb_token = string("mldb-dlqgw2jg9p")]
{
    func batch32(tensor history_pad_mask, tensor input, tensor lstm_0_c_in, tensor lstm_0_h_history, tensor lstm_0_h_in) {
        tensor input_int32 = cast(dtype = string("int32"), x = input);
        tensor embeddings = gather(axis = int32(0), indices = input_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(64))));
        tensor lstm_0, tensor lstm_0_h_out, tensor lstm_0_c_out = lstm(
                activation = string("tanh"),
                bias = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4104320))),
                cell_activation = string("tanh"),
                direction = string("forward"),
                initial_c = lstm_0_c_in,
                initial_h = lstm_0_h_in,
                output_sequence = bool(true),
                recurrent_activation = string("sigmoid"),
                weight_hh = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4105408))),
                weight_ih = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4236544))),
                x = embeddings);
        tensor history_pad_mask_fp16 = cast(dtype = string("fp16"), x = history_pad_mask);
        tensor history_pad_mask_3d = expand_dims(axes = tensor([1]), x = history_pad_mask_fp16);
        tensor history_softmax_mask = mul(x = history_pad_mask_3d, y = fp16(-10000));
        tensor lstm_0_T = transpose(perm = tensor([1, 0, 2]), x = lstm_0);
        tensor self_attn_0_dot = matmul(transpose_x = bool(false), transpose_y = bool(true), x = lstm_0_T, y = lstm_0_h_history);
        tensor self_attn_0_masked = add(x = self_attn_0_dot, y = history_softmax_mask);
        tensor self_attn_0_weights = softmax(axis = int32(-1), x = self_attn_0_masked);
        tensor self_attn_0_ctx_T = matmul(transpose_x = bool(false), transpose_y = bool(false), x = self_attn_0_weights, y = lstm_0_h_history);
        tensor self_attn_0_ctx = transpose(perm = tensor([1, 0, 2]), x = self_attn_0_ctx_T);
        tensor self_attn_0_h_ctx = concat(axis = int32(-1), interleave = bool(false), values = (lstm_0, self_attn_0_ctx));
        tensor self_attn_0_proj = linear(weight = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4367680))), x = self_attn_0_h_ctx);
        tensor self_attn_0 = tanh(x = self_attn_0_proj);
        tensor logits = linear(
                bias = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4433280))),
                weight = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4465408))),
                x = self_attn_0);
    } -> (lstm_0_h_out, lstm_0_c_out, logits);

    func logJointProb(tensor history_pad_mask, tensor input, tensor label, tensor lstm_0_c_in, tensor lstm_0_h_history, tensor lstm_0_h_in) {
        tensor input_int32 = cast(dtype = string("int32"), x = input);
        tensor embeddings = gather(axis = int32(0), indices = input_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(64))));
        tensor lstm_0, tensor lstm_0_h_out, tensor lstm_0_c_out = lstm(
                activation = string("tanh"),
                bias = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4104320))),
                cell_activation = string("tanh"),
                direction = string("forward"),
                initial_c = lstm_0_c_in,
                initial_h = lstm_0_h_in,
                output_sequence = bool(true),
                recurrent_activation = string("sigmoid"),
                weight_hh = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4105408))),
                weight_ih = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4236544))),
                x = embeddings);
        tensor history_pad_mask_fp16 = cast(dtype = string("fp16"), x = history_pad_mask);
        tensor history_pad_mask_3d = expand_dims(axes = tensor([1]), x = history_pad_mask_fp16);
        tensor history_softmax_mask = mul(x = history_pad_mask_3d, y = fp16(-10000));
        tensor lstm_0_T = transpose(perm = tensor([1, 0, 2]), x = lstm_0);
        tensor self_attn_0_dot = matmul(transpose_x = bool(false), transpose_y = bool(true), x = lstm_0_T, y = lstm_0_h_history);
        tensor self_attn_0_masked = add(x = self_attn_0_dot, y = history_softmax_mask);
        tensor self_attn_0_weights = softmax(axis = int32(-1), x = self_attn_0_masked);
        tensor self_attn_0_ctx_T = matmul(transpose_x = bool(false), transpose_y = bool(false), x = self_attn_0_weights, y = lstm_0_h_history);
        tensor self_attn_0_ctx = transpose(perm = tensor([1, 0, 2]), x = self_attn_0_ctx_T);
        tensor self_attn_0_h_ctx = concat(axis = int32(-1), interleave = bool(false), values = (lstm_0, self_attn_0_ctx));
        tensor self_attn_0_proj = linear(weight = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4367680))), x = self_attn_0_h_ctx);
        tensor self_attn_0 = tanh(x = self_attn_0_proj);
        tensor label_int32 = cast(dtype = string("int32"), x = label);
        tensor label_embeddings = gather(axis = int32(0), indices = label_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4465408))));
        tensor label_embeddings_nd = expand_dims(axes = tensor([-1]), x = label_embeddings);
        tensor model_embeddings_nd = expand_dims(axes = tensor([-1]), x = self_attn_0);
        tensor logits_without_biases_nd = matmul(transpose_x = bool(true), transpose_y = bool(false), x = model_embeddings_nd, y = label_embeddings_nd);
        tensor logits_without_biases = squeeze(axes = tensor([-1, -2]), x = logits_without_biases_nd);
        tensor logits_biases = gather(axis = int32(0), indices = label_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4433280))));
        tensor logits = add(x = logits_without_biases, y = logits_biases);
    } -> (lstm_0_h_out, lstm_0_c_out, logits);

    func logJointProb_batch32(tensor history_pad_mask, tensor input, tensor label, tensor lstm_0_c_in, tensor lstm_0_h_history, tensor lstm_0_h_in) {
        tensor input_int32 = cast(dtype = string("int32"), x = input);
        tensor embeddings = gather(axis = int32(0), indices = input_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(64))));
        tensor lstm_0, tensor lstm_0_h_out, tensor lstm_0_c_out = lstm(
                activation = string("tanh"),
                bias = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4104320))),
                cell_activation = string("tanh"),
                direction = string("forward"),
                initial_c = lstm_0_c_in,
                initial_h = lstm_0_h_in,
                output_sequence = bool(true),
                recurrent_activation = string("sigmoid"),
                weight_hh = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4105408))),
                weight_ih = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4236544))),
                x = embeddings);
        tensor history_pad_mask_fp16 = cast(dtype = string("fp16"), x = history_pad_mask);
        tensor history_pad_mask_3d = expand_dims(axes = tensor([1]), x = history_pad_mask_fp16);
        tensor history_softmax_mask = mul(x = history_pad_mask_3d, y = fp16(-10000));
        tensor lstm_0_T = transpose(perm = tensor([1, 0, 2]), x = lstm_0);
        tensor self_attn_0_dot = matmul(transpose_x = bool(false), transpose_y = bool(true), x = lstm_0_T, y = lstm_0_h_history);
        tensor self_attn_0_masked = add(x = self_attn_0_dot, y = history_softmax_mask);
        tensor self_attn_0_weights = softmax(axis = int32(-1), x = self_attn_0_masked);
        tensor self_attn_0_ctx_T = matmul(transpose_x = bool(false), transpose_y = bool(false), x = self_attn_0_weights, y = lstm_0_h_history);
        tensor self_attn_0_ctx = transpose(perm = tensor([1, 0, 2]), x = self_attn_0_ctx_T);
        tensor self_attn_0_h_ctx = concat(axis = int32(-1), interleave = bool(false), values = (lstm_0, self_attn_0_ctx));
        tensor self_attn_0_proj = linear(weight = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4367680))), x = self_attn_0_h_ctx);
        tensor self_attn_0 = tanh(x = self_attn_0_proj);
        tensor label_int32 = cast(dtype = string("int32"), x = label);
        tensor label_embeddings = gather(axis = int32(0), indices = label_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4465408))));
        tensor label_embeddings_nd = expand_dims(axes = tensor([-1]), x = label_embeddings);
        tensor model_embeddings_nd = expand_dims(axes = tensor([-1]), x = self_attn_0);
        tensor logits_without_biases_nd = matmul(transpose_x = bool(true), transpose_y = bool(false), x = model_embeddings_nd, y = label_embeddings_nd);
        tensor logits_without_biases = squeeze(axes = tensor([-1, -2]), x = logits_without_biases_nd);
        tensor logits_biases = gather(axis = int32(0), indices = label_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4433280))));
        tensor logits = add(x = logits_without_biases, y = logits_biases);
    } -> (lstm_0_h_out, lstm_0_c_out, logits);

    func main(tensor history_pad_mask, tensor input, tensor lstm_0_c_in, tensor lstm_0_h_history, tensor lstm_0_h_in) {
        tensor input_int32 = cast(dtype = string("int32"), x = input);
        tensor embeddings = gather(axis = int32(0), indices = input_int32, validate_indices = bool(false),
                x = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(64))));
        tensor lstm_0, tensor lstm_0_h_out, tensor lstm_0_c_out = lstm(
                activation = string("tanh"),
                bias = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4104320))),
                cell_activation = string("tanh"),
                direction = string("forward"),
                initial_c = lstm_0_c_in,
                initial_h = lstm_0_h_in,
                output_sequence = bool(true),
                recurrent_activation = string("sigmoid"),
                weight_hh = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4105408))),
                weight_ih = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4236544))),
                x = embeddings);
        tensor history_pad_mask_fp16 = cast(dtype = string("fp16"), x = history_pad_mask);
        tensor history_pad_mask_3d = expand_dims(axes = tensor([1]), x = history_pad_mask_fp16);
        tensor history_softmax_mask = mul(x = history_pad_mask_3d, y = fp16(-10000));
        tensor lstm_0_T = transpose(perm = tensor([1, 0, 2]), x = lstm_0);
        tensor self_attn_0_dot = matmul(transpose_x = bool(false), transpose_y = bool(true), x = lstm_0_T, y = lstm_0_h_history);
        tensor self_attn_0_masked = add(x = self_attn_0_dot, y = history_softmax_mask);
        tensor self_attn_0_weights = softmax(axis = int32(-1), x = self_attn_0_masked);
        tensor self_attn_0_ctx_T = matmul(transpose_x = bool(false), transpose_y = bool(false), x = self_attn_0_weights, y = lstm_0_h_history);
        tensor self_attn_0_ctx = transpose(perm = tensor([1, 0, 2]), x = self_attn_0_ctx_T);
        tensor self_attn_0_h_ctx = concat(axis = int32(-1), interleave = bool(false), values = (lstm_0, self_attn_0_ctx));
        tensor self_attn_0_proj = linear(weight = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4367680))), x = self_attn_0_h_ctx);
        tensor self_attn_0 = tanh(x = self_attn_0_proj);
        tensor logits = linear(
                bias = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4433280))),
                weight = tensor(BLOBFILE(path = string("@model_path/lstm.mil.weights"), offset = uint64(4465408))),
                x = self_attn_0);
    } -> (lstm_0_h_out, lstm_0_c_out, logits);
}
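The four functions are op-for-op copies of two graphs: batch32 duplicates main (a full-vocabulary logits head), and logJointProb_batch32 duplicates logJointProb, which scores only the supplied label ids by gathering the matching output rows and biases instead of running a full linear. Below is a minimal PyTorch sketch of what the graphs appear to compute, reconstructed by reading the ops above. The dump carries no shape or dtype annotations, so the vocabulary, embedding, and hidden sizes, the sequence-first layout, and the bias=False on the attention projection are assumptions rather than facts taken from the export.

```python
# Hypothetical reconstruction of the exported graph; all dimensions,
# layouts, and module names are placeholders inferred from the MIL ops.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LstmSelfAttnLM(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)   # gather @ blob offset 64
        self.lstm = nn.LSTM(embed_dim, hidden_dim)             # forward, tanh/sigmoid, seq-first
        self.attn_proj = nn.Linear(2 * hidden_dim, hidden_dim, bias=False)  # self_attn_0_proj
        self.out = nn.Linear(hidden_dim, vocab_size)           # weight @ 4465408, bias @ 4433280

    def forward(self, input_ids, history_pad_mask, h_history, h0, c0):
        # input_ids:        [seq, batch] token ids
        # history_pad_mask: [batch, hist], 1.0 where the history slot is padding
        # h_history:        [batch, hist, hidden] cached hidden states
        emb = self.embedding(input_ids)
        seq_h, (h_out, c_out) = self.lstm(emb, (h0, c0))       # seq_h: [seq, batch, hidden]

        # Dot-product self-attention of the new states over the cached history,
        # with the additive -10000 mask zeroing out attention to padded slots.
        q = seq_h.transpose(0, 1)                              # [batch, seq, hidden]
        scores = q @ h_history.transpose(1, 2)                 # [batch, seq, hist]
        scores = scores + history_pad_mask.unsqueeze(1) * -10000.0
        ctx = F.softmax(scores, dim=-1) @ h_history            # [batch, seq, hidden]
        ctx = ctx.transpose(0, 1)                              # back to [seq, batch, hidden]

        fused = torch.tanh(self.attn_proj(torch.cat([seq_h, ctx], dim=-1)))
        return h_out, c_out, self.out(fused)                   # main / batch32 head

    def score_labels(self, fused, label_ids):
        # logJointProb head: gather each label's tied output row and bias and
        # dot it with the model embedding at that step, instead of a full linear.
        w = self.out.weight[label_ids]                         # [seq, batch, hidden]
        b = self.out.bias[label_ids]                           # [seq, batch]
        return (fused * w).sum(dim=-1) + b
```

Note that the output weight matrix doubles as the label-embedding table: logJointProb gathers from the same blob offset (4465408) that main consumes as its linear weight, so both heads score against one tied parameter set.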