mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[MPS] Fix ops with bool issues in macOS Monterey (#94464)
Summary: - Remove redundant bool casts from scatter/gather - Make the workarounds for scatter/gather (for bool/uint8 data types) OS specific - use them only in macOS Monterey, ignore them starting with macOS Ventura - Make all tensors ranked in scatter Fixes following tests: ``` test_output_match_slice_scatter_cpu_bool test_output_match_select_scatter_cpu_bool test_output_match_diagonal_scatter_cpu_bool test_output_match_repeat_cpu_bool test_output_match_rot90_cpu_bool etc.. ``` Still failing on macOS Monterey (needs additional investigation): ``` test_output_match_scatter_cpu_bool ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94464 Approved by: https://github.com/kulinseth
This commit is contained in:
parent
5b1cedacde
commit
728dfeee48
|
|
@ -439,7 +439,16 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
|||
using CachedGraph = mps::MPSUnaryCachedGraph;
|
||||
|
||||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
|
||||
MPSDataType inputDataType = getMPSScalarType(self.scalar_type());
|
||||
MPSDataType outputDataType = getMPSScalarType(self.scalar_type());
|
||||
if (!is_macos_13_or_newer()) {
|
||||
if (self.scalar_type() == kBool) {
|
||||
inputDataType = MPSDataTypeInt8;
|
||||
}
|
||||
if (result.scalar_type() == kBool) {
|
||||
outputDataType = MPSDataTypeInt8;
|
||||
}
|
||||
}
|
||||
@autoreleasepool {
|
||||
NSString* ns_dims_key = [[ns_dims valueForKey:@"description"] componentsJoinedByString:@","];
|
||||
// A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types etc match the earlier created MPSGraph
|
||||
|
|
@ -454,7 +463,7 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
|||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph reverseTensor:inputTensor
|
||||
axes:ns_dims
|
||||
name:nil];
|
||||
|
|
@ -466,8 +475,10 @@ Tensor flip_mps(const Tensor& self, IntArrayRef dims) {
|
|||
}
|
||||
|
||||
// Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation
|
||||
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
|
||||
Placeholder inputPlaceholder = Placeholder(
|
||||
cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, result, /*mpsShape*/nil, /*gatherTensorData=*/false, outputDataType);
|
||||
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
|
|
@ -656,12 +667,15 @@ Tensor& index_select_out_mps(const Tensor & self,
|
|||
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
|
||||
auto inputType = getMPSDataType(self.scalar_type());
|
||||
auto outputType = getMPSDataType(output.scalar_type());
|
||||
if (inputType == MPSDataTypeUInt8 || inputType == MPSDataTypeBool) {
|
||||
inputType = MPSDataTypeInt8;
|
||||
if (inputType == MPSDataTypeUInt8 ||
|
||||
(!is_macos_13_or_newer() && inputType == MPSDataTypeBool)) {
|
||||
inputType = MPSDataTypeInt8;
|
||||
}
|
||||
if (outputType == MPSDataTypeUInt8 || outputType == MPSDataTypeBool) {
|
||||
outputType = MPSDataTypeInt8;
|
||||
if (outputType == MPSDataTypeUInt8 ||
|
||||
(!is_macos_13_or_newer() && outputType == MPSDataTypeBool)) {
|
||||
outputType = MPSDataTypeInt8;
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
string key = "index_select_out_mps" + getTensorsStringKey({self, index}) + ":" + std::to_string(dim);
|
||||
|
|
@ -792,10 +806,11 @@ Tensor & masked_fill__mps(Tensor& self, const Tensor & mask, const Scalar& value
|
|||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(
|
||||
cachedGraph->inputTensor_, self, /*mpsShape*/nullptr, /*gatherTensorData=*/true, inputDataType);
|
||||
cachedGraph->inputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder maskPlaceholder = Placeholder(
|
||||
cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nullptr, /*gatherTensorData=*/true, maskDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, self);
|
||||
cachedGraph->maskTensor_, *b_mask, /*mpsShape*/nil, /*gatherTensorData=*/true, maskDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, self, /*mpsShape*/nil, /*gatherTensorData=*/false, inputDataType);
|
||||
|
||||
// Create dictionary of inputs and outputs
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
|
|
|
|||
|
|
@ -71,6 +71,16 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
|||
}
|
||||
|
||||
auto stream = at::mps::getCurrentMPSStream();
|
||||
auto inputDataType = getMPSDataType(expanded_tensor.scalar_type());
|
||||
auto outputDataType = getMPSDataType(result.scalar_type());
|
||||
if (!is_macos_13_or_newer()) {
|
||||
if (expanded_tensor.scalar_type() == kBool) {
|
||||
inputDataType = MPSDataTypeInt8;
|
||||
}
|
||||
if (result.scalar_type() == kBool) {
|
||||
outputDataType = MPSDataTypeInt8;
|
||||
}
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
|
||||
|
|
@ -84,7 +94,7 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
|||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, expanded_tensor);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(expanded_tensor));
|
||||
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor
|
||||
withMultiplier:getMPSShape(repeats)
|
||||
name:nil];
|
||||
|
|
@ -97,8 +107,10 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
|
|||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, expanded_tensor);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
|
||||
Placeholder selfPlaceholder = Placeholder(
|
||||
cachedGraph->inputTensor_, expanded_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, inputDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, result, /*mpsShape=*/nil, /*gatherTensorData*/false, outputDataType);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
|
||||
|
|
|
|||
|
|
@ -51,11 +51,13 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
|||
if(i != dim && [index_shape[i] intValue] < [input_shape[i] intValue])
|
||||
needSlice = true;
|
||||
}
|
||||
// input and output types are always the same
|
||||
auto dtype = getMPSDataType(self.scalar_type());
|
||||
// workaround for UInt8 and Bool issues in MPS backend
|
||||
if (dtype == MPSDataTypeUInt8 || dtype == MPSDataTypeBool) {
|
||||
dtype = MPSDataTypeInt8;
|
||||
auto input_type = getMPSDataType(self.scalar_type());
|
||||
auto output_type = getMPSDataType(output.scalar_type());
|
||||
if (input_type == MPSDataTypeUInt8 || ((input_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
|
||||
input_type = MPSDataTypeInt8;
|
||||
}
|
||||
if (output_type == MPSDataTypeUInt8 || ((output_type == MPSDataTypeBool && !is_macos_13_or_newer()))) {
|
||||
output_type = MPSDataTypeInt8;
|
||||
}
|
||||
string key = "gather_out_mps" + getTensorsStringKey({self, index, output}) + ":" + std::to_string(dim);
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
|
|
@ -68,7 +70,7 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
|||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
newCachedGraph = new CachedGraph(mpsGraph);
|
||||
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, dtype, input_shape);
|
||||
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, getMPSShape(self));
|
||||
MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index);
|
||||
|
||||
MPSGraphTensor* getInput = inputTensor;
|
||||
|
|
@ -111,9 +113,9 @@ TORCH_IMPL_FUNC(gather_out_mps)
|
|||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, dtype);
|
||||
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, true, input_type);
|
||||
Placeholder indexPlaceholder = Placeholder(cachedGraph->indexTensor_, index, index_shape);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, dtype);
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nullptr, false, output_type);
|
||||
|
||||
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
|
||||
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
|
||||
|
|
|
|||
|
|
@ -392,7 +392,8 @@ TORCH_IMPL_FUNC(cat_out_mps)
|
|||
if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
|
||||
outputDataType = MPSDataTypeInt8;
|
||||
}
|
||||
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, nil, false, outputDataType);
|
||||
Placeholder outputPlaceholder = Placeholder(
|
||||
cachedGraph->outputTensor_, out, /*mpsShape=*/nil, /*gatherTensorData=*/false, outputDataType);
|
||||
|
||||
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
|
||||
for (auto& inputPlaceholder : inputPlaceholders) {
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
|
|||
dataType: inputType] autorelease];
|
||||
if (needsScatter) {
|
||||
auto updatesType = getMPSScalarType(src.scalar_type());
|
||||
if (updatesType == MPSDataTypeUInt8 || updatesType == MPSDataTypeBool) {
|
||||
if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
updatesType = MPSDataTypeInt8;
|
||||
}
|
||||
|
||||
|
|
@ -69,10 +69,10 @@ static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src,
|
|||
strideScalars[i] = getMPSScalar(strides[i], ScalarType::Int);
|
||||
feeds[cachedGraph->strideTensors[i]] = getMPSGraphTensorFromScalar(stream, strideScalars[i]);
|
||||
}
|
||||
// Workaround for MPSShaderLibrary bug
|
||||
// TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved
|
||||
auto outputType = getMPSDataType(output.scalar_type());
|
||||
if (outputType == MPSDataTypeUInt8) {
|
||||
// Workaround for MPSShaderLibrary bug in macOS Monterey
|
||||
// This is fixed in macOS Ventura
|
||||
auto outputType = getMPSScalarType(output.scalar_type());
|
||||
if (outputType == MPSDataTypeUInt8 || (outputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
outputType = MPSDataTypeInt8;
|
||||
}
|
||||
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: outputBuffer
|
||||
|
|
@ -505,7 +505,6 @@ MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mp
|
|||
static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const IntArrayRef& size,
|
||||
const IntArrayRef& stride, int64_t offset,
|
||||
const IntArrayRef& base_shape, bool needsScatter,
|
||||
const bool needsBoolCast,
|
||||
MPSGraphTensor* updatesTensor)
|
||||
{
|
||||
MPSGraph* mpsGraph = cachedGraph->graph();
|
||||
|
|
@ -548,23 +547,9 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
|||
name: nil];
|
||||
MPSGraphTensor *inputTensor = cachedGraph->inputTensor;
|
||||
|
||||
// Workaround for bool scatter/gather deficiency
|
||||
// See https://github.com/pytorch/pytorch/issues/82663
|
||||
if (needsBoolCast) {
|
||||
inputTensor = [mpsGraph castTensor:inputTensor
|
||||
toType:MPSDataTypeInt8
|
||||
name:@"Cast away from bool"];
|
||||
}
|
||||
|
||||
if (!needsScatter) {
|
||||
MPSGraphTensor *outputTensor = asStridedLayer_pattern(mpsGraph, inputTensor, shape_size, size, stride, offset);
|
||||
|
||||
if (outputTensor) {
|
||||
if (needsBoolCast) {
|
||||
outputTensor = [mpsGraph castTensor:outputTensor
|
||||
toType:MPSDataTypeBool
|
||||
name:@"Cast back to bool"];
|
||||
}
|
||||
return outputTensor;
|
||||
}
|
||||
}
|
||||
|
|
@ -597,14 +582,6 @@ static MPSGraphTensor* chainViewOperation(ViewCachedGraph* cachedGraph, const In
|
|||
withShapeTensor: shapeTensor
|
||||
name: nil];
|
||||
}
|
||||
|
||||
// Workaround for bool scatter/gather deficiency
|
||||
// See https://github.com/pytorch/pytorch/issues/82663
|
||||
if (needsBoolCast) {
|
||||
outputTensor = [mpsGraph castTensor:outputTensor
|
||||
toType:MPSDataTypeBool
|
||||
name:@"Cast back to bool"];
|
||||
}
|
||||
}
|
||||
return outputTensor;
|
||||
}
|
||||
|
|
@ -660,13 +637,13 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
|
|||
MPSGraph* mpsGraph = make_mps_graph();
|
||||
MPSGraphTensor* updatesTensor = nil;
|
||||
newCachedGraph = new ViewCachedGraph(mpsGraph);
|
||||
// Workaround for MPSShaderLibrary bug
|
||||
// TODO: Remove once https://github.com/pytorch/pytorch/issues/82305 is resolved
|
||||
// Workaround for MPSShaderLibrary bug in macOS Monterey
|
||||
// This is fixed in macOS Ventura
|
||||
auto inputType = getMPSScalarType(self.scalar_type());
|
||||
if (inputType == MPSDataTypeUInt8) {
|
||||
if (inputType == MPSDataTypeUInt8 || (inputType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
inputType = MPSDataTypeInt8;
|
||||
}
|
||||
auto needsBoolCast = inputType == MPSDataTypeBool;
|
||||
|
||||
// Self is the input tensor we are creating view of
|
||||
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape));
|
||||
newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[@1]);
|
||||
|
|
@ -675,10 +652,10 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
|
|||
}
|
||||
if (needsScatter) {
|
||||
auto updatesType = getMPSScalarType(updates.scalar_type());
|
||||
if (updatesType == MPSDataTypeUInt8) {
|
||||
updatesType = MPSDataTypeInt8;
|
||||
if (updatesType == MPSDataTypeUInt8 || (updatesType == MPSDataTypeBool && !is_macos_13_or_newer())) {
|
||||
updatesType = MPSDataTypeInt8;
|
||||
}
|
||||
newCachedGraph->updatesTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, updatesType);
|
||||
newCachedGraph->updatesTensor = mpsGraphRankedPlaceHolder(mpsGraph, updatesType, getMPSShape(self.numel()));
|
||||
updatesTensor = newCachedGraph->updatesTensor;
|
||||
if (inputType != updatesType) {
|
||||
updatesTensor = [mpsGraph castTensor:updatesTensor
|
||||
|
|
@ -686,7 +663,7 @@ static ViewCachedGraph* createViewGraph(const Tensor& self, const Tensor &update
|
|||
name:@"castUpdatesTensor"];
|
||||
}
|
||||
}
|
||||
newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, needsBoolCast, updatesTensor);
|
||||
newCachedGraph->outputTensor = chainViewOperation(newCachedGraph, size, stride, storage_offset, base_shape, needsScatter, updatesTensor);
|
||||
}
|
||||
return newCachedGraph;
|
||||
}));
|
||||
|
|
|
|||
|
|
@ -8718,7 +8718,7 @@ class TestConsistency(TestCase):
|
|||
'diag': ['f32', 'i32'],
|
||||
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'diagflat': ['f32', 'i32'],
|
||||
'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'diagonal_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'diff': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'dist': ['f32'],
|
||||
'dot': ['f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
|
|
@ -8840,25 +8840,25 @@ class TestConsistency(TestCase):
|
|||
'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'reciprocal': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
|
||||
'remainder' : ['f32', 'f16'],
|
||||
'repeat': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'repeat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'repeat_interleave': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'resize_': ['b8', 'i16', 'i32', 'i64', 'u8'],
|
||||
'resize_as_': ['b8', 'i16', 'i32', 'i64', 'u8'],
|
||||
'resolve_conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'resolve_neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'rot90': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'rot90': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'round': ['f32', 'f16', 'i16', 'i32', 'i64'],
|
||||
'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
|
||||
'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'select_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
|
||||
'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
|
||||
'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],
|
||||
'sinh': ['b8', 'f32', 'i16', 'i32', 'u8'],
|
||||
'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'slice_scatter': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
|
||||
'softmax': ['f32'],
|
||||
'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
|
||||
|
|
@ -9144,7 +9144,6 @@ class TestConsistency(TestCase):
|
|||
'pow': [torch.int64],
|
||||
'select_scatter': [torch.uint8],
|
||||
'sigmoid': [torch.int64],
|
||||
'slice_scatter': [torch.uint8],
|
||||
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user