// -*- Metal -*-
//===-- MetalTensorOpsUtility ------------------------------------------------------===//
// Copyright (c) 2025 Apple Inc. All rights reserved
//===----------------------------------------------------------------------===//

#ifndef __MetalTensorOpsUtility__
#define __MetalTensorOpsUtility__

#if defined(__METAL_VERSION__) && defined(__HAVE_TENSOR__)

#include "MPPTensorOpsTypes.h"

// NOTE(review): in the reviewed text every template parameter/argument list
// was missing (e.g. `template T __MIN(...)`, `__assert_false_v,` and
// `template <> struct __type_to_tensor_ops_datatype {`), which is ill-formed
// and would make all specializations redefine the same name. The lists are
// restored below. Each specialization argument was chosen to match the
// enumerator it maps to -- confirm the exact type names against the toolchain.

namespace mpp
{
namespace tensor_ops
{
namespace __tensor_ops_detail
{

// Returns x when x < y, otherwise y (so y is returned when the two compare
// equal).
template <typename T>
T __MIN(T x, T y)
{
  return (x < y) ? x : y;
}

// Returns x when x > y, otherwise y (so y is returned when the two compare
// equal).
template <typename T>
T __MAX(T x, T y)
{
  return (x > y) ? x : y;
}

// Compile-time map from an element type T to its __tensor_ops_datatype
// enumerator, exposed as the static member `value`.
//
// The primary template handles every type without a specialization below and
// rejects it with a static_assert. The condition is written in terms of T
// (via __assert_false_v, presumably declared in MPPTensorOpsTypes.h) so that
// the assertion only fires when the primary template is instantiated.
template <typename T>
struct __type_to_tensor_ops_datatype
{
  static_assert(__tensor_ops_detail::__assert_false_v<T>,
                "unsupported data type");
};

// --- Floating-point element types -------------------------------------------

template <>
struct __type_to_tensor_ops_datatype<float>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_float32;
};

template <>
struct __type_to_tensor_ops_datatype<half>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_float16;
};

#if __HAVE_BFLOAT__
template <>
struct __type_to_tensor_ops_datatype<bfloat>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_bfloat16;
};
#endif

// --- 8-bit integer element types --------------------------------------------

template <>
struct __type_to_tensor_ops_datatype<int8_t>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_int8;
};

template <>
struct __type_to_tensor_ops_datatype<uint8_t>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_uint8;
};

// --- 4-bit integer element types (only when the toolchain provides them) ----

#if __HAVE_INT4B_FORMAT_TYPE__
// NOTE(review): the 4-bit type names are inferred from the feature macro
// name (__HAVE_INT4B_FORMAT_TYPE__); verify them against the Metal headers.
template <>
struct __type_to_tensor_ops_datatype<int4b_format>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_int4;
};

template <>
struct __type_to_tensor_ops_datatype<uint4b_format>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_uint4;
};
#endif

// --- 16-bit integer element types -------------------------------------------

template <>
struct __type_to_tensor_ops_datatype<int16_t>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_int16;
};

template <>
struct __type_to_tensor_ops_datatype<uint16_t>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_uint16;
};

// --- 32-bit integer element types -------------------------------------------

template <>
struct __type_to_tensor_ops_datatype<int32_t>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_int32;
};

template <>
struct __type_to_tensor_ops_datatype<uint32_t>
{
  static constant __tensor_ops_datatype value = __tensor_ops_datatype_uint32;
};

} // namespace __tensor_ops_detail
} // namespace tensor_ops
} // namespace mpp

#endif // defined(__METAL_VERSION__) && defined(__HAVE_TENSOR__)

#endif // __MetalTensorOpsUtility__