[MPS] Fix LSTM backward and forward pass (#95137)
Fixes #91694
Fixes #92615

Several transpositions were missing in the backward graph when `batch_first=True`; #91694 does not reproduce with `batch_first=False`. After fixing the transpose issue, I thought I could finally use LSTM freely in my project, and then I got horrific results during training. That seems related to #92615, so I decided to fix the LSTM backward step completely. I collected all my findings in this thread, and it seems like I succeeded.

Funnily enough, the backward tests were completely disabled before and were not passing:

```python
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
```

UPD: the forward pass of the multi-layer version was also wrong, due to the incorrect `initState, initCell` slices. The tests were passing only because the states were initialized with zeros. *Accidentally* fixed this too.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
This commit is contained in:
parent 86efa104f5
commit b9e95158d5
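As a quick illustration of what the fix is meant to guarantee (this snippet is not part of the patch; the helper name, sizes, and tolerances below are my own), a multi-layer, `batch_first=True` LSTM with non-zero initial states should now match the CPU reference on MPS for both the forward outputs and the parameter gradients. Multi-layer LSTM on MPS additionally requires macOS 13 or newer.

```python
# Hedged sketch of a CPU vs. MPS parity check for the forward and backward pass.
import torch
import torch.nn as nn

def run(device, lstm, inp, hx, cx):
    lstm = lstm.to(device)
    inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
    out, _ = lstm(inp, (hx, cx))
    loss = out.sum()
    # Gradients w.r.t. all parameters, computed on the given device.
    grads = torch.autograd.grad(loss, list(lstm.parameters()))
    return out.cpu(), [g.cpu() for g in grads]

if torch.backends.mps.is_available():
    torch.manual_seed(42)
    # 2 layers and batch_first=True exercise both bugs fixed by this commit.
    lstm = nn.LSTM(input_size=2, hidden_size=4, num_layers=2, batch_first=True)
    inp = torch.randn(3, 5, 2)   # (batch, seq, feature)
    hx = torch.randn(2, 3, 4)    # (layers, batch, hidden), non-zero on purpose
    cx = torch.randn(2, 3, 4)

    cpu_out, cpu_grads = run("cpu", lstm, inp, hx, cx)
    mps_out, mps_grads = run("mps", lstm, inp, hx, cx)

    torch.testing.assert_close(cpu_out, mps_out, rtol=1e-4, atol=1e-4)
    for g_cpu, g_mps in zip(cpu_grads, mps_grads):
        torch.testing.assert_close(g_cpu, g_mps, rtol=1e-4, atol=1e-4)
```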
@@ -1423,7 +1423,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
 }
 #ifdef USE_MPS
 if (_input.is_mps() && !bidirectional) {
-std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
 num_layers, dropout_p, train, bidirectional, batch_first);
 std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
 return return_values;
@@ -23,7 +23,7 @@ std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
 return output_dimensions;
 }

-std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
+std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
 using namespace mps;

 //Projections are not currently supported, raise an error if needed
@@ -32,6 +32,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 AT_ERROR("LSTM with projections is not currently supported with MPS.");
 }

+TORCH_CHECK(!(!is_macos_13_or_newer() && num_layers > 1), "Multi-layer LSTM support in MPS available only on MacOS 13 onwards");
+
 std::vector<Tensor> kernel_weights;
 std::vector<Tensor> recurrent_kernel_weights;
 std::vector<Tensor> biases;
@@ -56,8 +58,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
 NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
 NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
-std::vector<MPSGraphTensor*> outputCellStateFwdVector_;
-std::vector<MPSGraphTensor*> outputZStateVector_;
 };

 MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@@ -79,6 +79,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
 NSMutableArray<MPSGraphTensor*> *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
 NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
+NSMutableArray<MPSGraphTensor*> *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];

 for (size_t i = 0; i < num_layers; i += 1) {
 [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
@@ -107,16 +108,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 }

 MPSGraphTensor* inputTensor_ = inputTensor;
-MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
-dimension:0
-start:0
-length:1
-name:nil];
-MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
-dimension:0
-start:0
-length:1
-name:nil];
 NSArray<MPSGraphTensor*>* outputs = nil;
 NSMutableArray<MPSGraphTensor*>* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
 NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
@@ -129,6 +120,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 secondaryTensor:recurrentBiasList[i]
 name:nil];
 }
+MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
+dimension:0
+start:i
+length:1
+name:nil];
+MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
+dimension:0
+start:i
+length:1
+name:nil];
 outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
 recurrentWeight:recurrentKernelWeightsList[i]
 inputWeight:kernelWeightsList[i]
@@ -138,17 +139,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 descriptor:opDesc
 name:nil];

-stateTensor_ = [mpsGraph sliceTensor:stateTensor
-dimension:0
-start:i
-length:1
-name:nil];
-cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
-dimension:0
-start:i
-length:1
-name:nil];
 inputTensor_ = [outputs objectAtIndex:0];
+// no need to keep a final layer output copy as it is
+// returned anyway and not used in backprop
+if(i != num_layers - 1) {
+[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
+axis:0
+name:nil]];
+}
 if(dropout_p>0.0 && train && (i!=num_layers-1)) {
 inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
 rate:dropout_p
@@ -166,7 +164,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 name:nil]];
 }

-MPSGraphTensor* outputTensor = [outputs objectAtIndex:0];
+MPSGraphTensor* outputTensor = inputTensor_;
 if (batch_first) {
 outputTensor = [mpsGraph transposeTensor:outputTensor
 dimension:0
@@ -185,8 +183,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray
 dimension:0
 name:nil];
+MPSGraphTensor* layersOutputs = (num_layers > 1)
+? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil]
+: nil;

-std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd};
+std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs};
 newCachedGraph->inputTensors_ = inputTensors;
 newCachedGraph->outputTensors_ = outputTensors;
 newCachedGraph->kernelWeightsList_ = kernelWeightsList;
@@ -204,10 +205,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
 NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;

-Placeholder kernelWeight;
-Placeholder recurrentKernelWeight;
-Placeholder bias;
-Placeholder recurrentBias;
+Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias;

 NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
 for (size_t i = 0; i < num_layers; i+=1) {
 kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
@@ -236,6 +235,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 Tensor cy = at::empty_like(hx[1], input.options());
 Tensor zState = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[3])), input.options());
 Tensor cellStateFwd = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[4])), input.options());
+Tensor layerOutputs = (num_layers > 1)
+? at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[5])), input.options())
+: at::empty({ 1 }, input.options()); // not used if num_layers == 1

 Placeholder outputPlaceholder0 = Placeholder(cachedGraph->outputTensors_[0], output);
 Placeholder outputPlaceholder1 = Placeholder(cachedGraph->outputTensors_[1], hy);
@@ -243,20 +245,25 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
 Placeholder outputPlaceholder3 = Placeholder(cachedGraph->outputTensors_[3], zState);
 Placeholder outputPlaceholder4 = Placeholder(cachedGraph->outputTensors_[4], cellStateFwd);

-NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [@{
 outputPlaceholder0.getMPSGraphTensor() : outputPlaceholder0.getMPSGraphTensorData(),
 outputPlaceholder1.getMPSGraphTensor() : outputPlaceholder1.getMPSGraphTensorData(),
 outputPlaceholder2.getMPSGraphTensor() : outputPlaceholder2.getMPSGraphTensorData(),
 outputPlaceholder3.getMPSGraphTensor() : outputPlaceholder3.getMPSGraphTensorData(),
-outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData()
-};
+outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData(),
+} mutableCopy];

+if (num_layers > 1) {
+Placeholder outputPlaceholder5 = Placeholder(cachedGraph->outputTensors_[5], layerOutputs);
+[results setObject:outputPlaceholder5.getMPSGraphTensorData() forKey: outputPlaceholder5.getMPSGraphTensor()];
+}

 runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-return std::make_tuple(output, hy, cy, zState, cellStateFwd);
+return std::make_tuple(output, hy, cy, zState, cellStateFwd, layerOutputs);
 }
 }

-std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
+std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, const Tensor& layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
 using namespace mps;
 const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] {return Tensor();});
 const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] {return Tensor();});
@@ -287,12 +294,12 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
 NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
 NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
-NSMutableArray<MPSGraphTensor*> *gradOutput_ = nil;
 NSMutableArray<MPSGraphTensor*> *gradRecWeights_ = nil;
 NSMutableArray<MPSGraphTensor*> *gradWeights_ = nil;
 NSMutableArray<MPSGraphTensor*> *gradBias_ = nil;
-NSMutableArray<MPSGraphTensor*> *gradState_ = nil;
-NSMutableArray<MPSGraphTensor*> *gradCellState_ = nil;
+MPSGraphTensor* gradOutput_ = nil;
+MPSGraphTensor* gradState_ = nil;
+MPSGraphTensor* gradCellState_ = nil;
 };

 MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@@ -333,8 +340,22 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 MPSGraphTensor* gradientCyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy.scalar_type()), getMPSShape(grad_cy));
 MPSGraphTensor* gradientHyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy.scalar_type()), getMPSShape(grad_hy));
 MPSGraphTensor* cellStateFwdTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd.scalar_type()), getMPSShape(cell_state_fwd));
+MPSGraphTensor* layersOutputsTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs.scalar_type()), getMPSShape(layersOutputs));

+std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor, layersOutputsTensor};

+if (batch_first) {
+inputTensor = [mpsGraph transposeTensor: inputTensor
+dimension: 0
+withDimension: 1
+name: nil];

+gradientTensor = [mpsGraph transposeTensor: gradientTensor
+dimension: 0
+withDimension: 1
+name: nil];
+}

-std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor};
 newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList;
 newCachedGraph->kernelWeightsList_ = kernelWeightsList;
 newCachedGraph->biasList_ = kernelBiasList;
@@ -350,7 +371,6 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c

 NSArray<MPSGraphTensor*>* outputs = nil;

-NSMutableArray<MPSGraphTensor*>* gradOutputArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
 NSMutableArray<MPSGraphTensor*>* gradRecWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
 NSMutableArray<MPSGraphTensor*>* gradWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
 NSMutableArray<MPSGraphTensor*>* gradBiasArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
@@ -406,7 +426,23 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 length:1
 name:nil];

-outputs = [mpsGraph LSTMGradientsWithSourceTensor: inputTensor
+MPSGraphTensor* iterationInputTensor_ = nil;
+if (i == 0) {
+iterationInputTensor_ = inputTensor;
+} else {
+iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor
+dimension: 0
+// last element in layersOutputsTensor contains
+// **inputs** for the last layer
+start: i - num_layers
+length: 1
+name: nil];
+iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_
+axis:0
+name: nil];
+}

+outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_
 recurrentWeight: recurrentKernelWeightsList[i]
 sourceGradient: gradientTensor_
 zState: zState
@@ -423,22 +459,30 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 name: nil];

 gradientTensor_ = [outputs objectAtIndex:0];
-[gradOutputArray addObject:[outputs objectAtIndex:0]];
-[gradRecWeightsArray addObject:[outputs objectAtIndex:1]];
-[gradWeightsArray addObject:[outputs objectAtIndex:2]];
-[gradBiasArray addObject:[outputs objectAtIndex:3]];
-[gradStateArray addObject:[outputs objectAtIndex:4]];
-[gradCellStateArray addObject:[outputs objectAtIndex:5]];
+[gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0];
+[gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0];
+[gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0];
+[gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil] atIndex:0];
+[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0];
 }
 std::vector<MPSGraphTensor*> outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]};

+if (batch_first) {
+MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_
+dimension: 0
+withDimension: 1
+name:nil];
+newCachedGraph->gradOutput_ = gradientTensorTransposed;
+} else {
+newCachedGraph->gradOutput_ = gradientTensor_;
+}

 newCachedGraph->outputTensors_ = outputTensors;
-newCachedGraph->gradOutput_ = gradOutputArray;
 newCachedGraph->gradRecWeights_ = gradRecWeightsArray;
 newCachedGraph->gradWeights_ = gradWeightsArray;
 newCachedGraph->gradBias_ = gradBiasArray;
-newCachedGraph->gradState_ = gradStateArray;
-newCachedGraph->gradCellState_ = gradCellStateArray;
+newCachedGraph->gradState_ = [mpsGraph concatTensors:gradStateArray dimension: 0 name: nil];
+newCachedGraph->gradCellState_ = [mpsGraph concatTensors:gradCellStateArray dimension: 0 name: nil];
 }
 return newCachedGraph;
 });
@@ -453,6 +497,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 Placeholder cellStateFwdPlaceholder = Placeholder(cachedGraph->inputTensors_[5], cell_state_fwd);
 Placeholder gradientHyPlaceholder = Placeholder(cachedGraph->inputTensors_[6], grad_hy);
 Placeholder gradientCyPlaceholder = Placeholder(cachedGraph->inputTensors_[7], grad_cy);
+Placeholder layersOutputsPlaceholder = Placeholder(cachedGraph->inputTensors_[8], layersOutputs);

 NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
 [feeds setObject:gradientPlaceholder.getMPSGraphTensorData() forKey:gradientPlaceholder.getMPSGraphTensor()];
@@ -463,6 +508,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 [feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()];
 [feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()];
 [feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() forKey:cellStateFwdPlaceholder.getMPSGraphTensor()];
+[feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() forKey:layersOutputsPlaceholder.getMPSGraphTensor()];

 NSMutableArray<MPSGraphTensor*> *kernelWeightsList = cachedGraph->kernelWeightsList_;
 NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
@@ -485,62 +531,55 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
 }
 }

-Tensor output = at::empty_like(input);
-Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[0]);
-Tensor grad_weights = at::empty_like(kernel_weights[0]);
-Tensor grad_bias = at::empty((kernel_weights[0].size(0)), kernel_weights[0].options());
-Tensor grad_state = at::empty_like(hx[0]);
-Tensor grad_cell_state = at::empty_like(hx[1]);
-Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensors_[0], output);
-Placeholder gradRecWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[1], grad_rec_weights);
-Placeholder gradWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[2], grad_weights);
-Placeholder gradBiasPlaceholder = Placeholder(cachedGraph->outputTensors_[3], grad_bias);
-Placeholder gradStatePlaceholder = Placeholder(cachedGraph->outputTensors_[4], grad_state);
-Placeholder gradCellStatePlaceholder = Placeholder(cachedGraph->outputTensors_[5], grad_cell_state);
+Tensor output_out = at::empty_like(input);
+Tensor grad_state_out = at::empty_like(hx[0]);
+Tensor grad_cell_state_out = at::empty_like(hx[1]);

-std::vector<Tensor> grad_hx = {grad_state, grad_cell_state};
+std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};

 NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = [[[NSMutableDictionary alloc] init] autorelease];
-NSMutableArray<MPSGraphTensor*> *gradOutputArray = cachedGraph->gradOutput_;
 NSMutableArray<MPSGraphTensor*> *gradRecWeightsArray = cachedGraph->gradRecWeights_;
 NSMutableArray<MPSGraphTensor*> *gradWeightsArray = cachedGraph->gradWeights_;
 NSMutableArray<MPSGraphTensor*> *gradBiasArray = cachedGraph->gradBias_;
-NSMutableArray<MPSGraphTensor*> *gradStateArray = cachedGraph->gradState_;
-NSMutableArray<MPSGraphTensor*> *gradCellStateArray = cachedGraph->gradCellState_;
-Placeholder gradOutPlaceholder;
+MPSGraphTensor* gradOutput = cachedGraph->gradOutput_;
+MPSGraphTensor* gradState = cachedGraph->gradState_;
+MPSGraphTensor* gradCellState = cachedGraph->gradCellState_;

+Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out);
+Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
+Placeholder outputPlaceholder = Placeholder(gradOutput, output_out);
+[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
+[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
+[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];

+Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;

 std::vector<Tensor> weights;
 for (int i = 0; i < num_layers; i++) {
-Tensor output = at::empty_like(input);
 Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]);
 Tensor grad_weights = at::empty_like(kernel_weights[i]);
-Tensor grad_bias = at::empty((kernel_weights[0].size(0)), kernel_weights[0].options());
-Tensor grad_state = at::empty_like(hx[0]);
-Tensor grad_cell_state = at::empty_like(hx[1]);
+Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
 weights.push_back(grad_weights);
 weights.push_back(grad_rec_weights);

 if(has_biases) {
 weights.push_back(grad_bias);
 weights.push_back(grad_bias);
 }
-gradOutPlaceholder = Placeholder([gradOutputArray objectAtIndex:i], output);
-gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:i], grad_rec_weights);
-gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:i], grad_weights);
-gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:i], grad_bias);
-gradStatePlaceholder = Placeholder([gradStateArray objectAtIndex:i], grad_state);
-gradCellStatePlaceholder = Placeholder([gradCellStateArray objectAtIndex:i], grad_cell_state);

-[results setObject:gradOutPlaceholder.getMPSGraphTensorData() forKey:gradOutPlaceholder.getMPSGraphTensor()];
-[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
+gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights);
+gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights);
+gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);

 [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
-[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
-[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
+[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
 [results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
 }

 runMPSGraph(stream, cachedGraph->graph(), feeds, results);

-return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output, grad_hx, weights);
+return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output_out, grad_hx, weights);

 }
 }
@@ -7200,12 +7200,12 @@

 # MPS LSTM implementation

-- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
   dispatch:
     MPS: _lstm_mps
   autogen: _lstm_mps.out

-- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
   dispatch:
     MPS: lstm_mps_backward
   autogen: lstm_mps_backward.out
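For reference, a hedged sketch of how the updated schema above surfaces at the dispatcher level: `_lstm_mps` now returns six tensors, and the sixth (`layersOutputs`) is what feeds `lstm_mps_backward`. The call below is illustrative only; it assumes an MPS device (and macOS 13+ for multi-layer), and uses `_flat_weights` simply as a convenient source for the flat `params` list.

```python
import torch
import torch.nn as nn

if torch.backends.mps.is_available():
    lstm = nn.LSTM(2, 4, num_layers=2).to("mps")
    inp = torch.randn(5, 3, 2, device="mps")   # (seq, batch, feature)
    hx = torch.zeros(2, 3, 4, device="mps")
    cx = torch.zeros(2, 3, 4, device="mps")

    # Argument order follows the schema in native_functions.yaml above.
    out = torch.ops.aten._lstm_mps(
        inp, [hx, cx], list(lstm._flat_weights),
        True,   # has_biases
        2,      # num_layers
        0.0,    # dropout
        True,   # train
        False,  # bidirectional
        False,  # batch_first
    )
    print(len(out))  # 6: output, hy, cy, zState, cellStateFwd, layersOutputs
```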
test/test_mps.py
@@ -9008,64 +9008,91 @@ class TestAdvancedIndexing(TestCaseMPS):

 class TestRNNMPS(TestCaseMPS):
     def test_lstm_1(self, device="mps", dtype=torch.float32):
+        for layers in [1] if product_version < 13.0 else [1, 2, 5]:
+            torch.random.manual_seed(42)
+            rnn = nn.LSTM(7, 4, layers, device="cpu")
+            input = torch.randn(2, 3, 7, device="cpu")
+            hx = torch.randn(layers, 3, 4, device="cpu")
+            cx = torch.randn(layers, 3, 4, device="cpu")

-        rnn = nn.LSTM(1, 4, 2, device="cpu")
-        input = torch.randn(2, 3, 1, device="cpu")
-        hx = torch.zeros(2, 3, 4, device="cpu")
-        cx = torch.zeros(2, 3, 4, device="cpu")
-        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))

+            cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+            rnn = rnn.to(device)
+            input = input.to(device)
+            hx = hx.to(device)
+            cx = cx.to(device)
+            output, (hn, cn) = rnn(input, (hx, cx))

-        rnn = rnn.to(device)
-        input = input.to(device)
-        hx = hx.to(device)
-        cx = cx.to(device)
-        output, (hn, cn) = rnn(input, (hx, cx))
-        self.assertEqual(cpu_output, output)
-        self.assertEqual(cpu_hn, hn)
-        self.assertEqual(cpu_cn, cn)

+            self.assertEqual(cpu_output, output)
+            self.assertEqual(cpu_hn, hn)
+            self.assertEqual(cpu_cn, cn)

+            # test batch_first
+            rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True)
+            input = torch.randn(3, 2, 7, device="cpu")
+            hx = torch.randn(layers, 3, 4, device="cpu")
+            cx = torch.randn(layers, 3, 4, device="cpu")
+            cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))

-        # test batch_first
-        rnn = nn.LSTM(1, 4, 2, device="cpu", batch_first=True)
-        input = torch.randn(3, 2, 1, device="cpu")
-        hx = torch.zeros(2, 3, 4, device="cpu")
-        cx = torch.zeros(2, 3, 4, device="cpu")
-        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
-        rnn = rnn.to(device)
-        input = input.to(device)
-        hx = hx.to(device)
-        cx = cx.to(device)
-        output, (hn, cn) = rnn(input, (hx, cx))

+            rnn = rnn.to(device)
+            input = input.to(device)
+            hx = hx.to(device)
+            cx = cx.to(device)
+            output, (hn, cn) = rnn(input, (hx, cx))

-        self.assertEqual(cpu_output, output)
-        self.assertEqual(cpu_hn, hn)
-        self.assertEqual(cpu_cn, cn)

+            self.assertEqual(cpu_output, output)
+            self.assertEqual(cpu_hn, hn)
+            self.assertEqual(cpu_cn, cn)

+    def test_lstm_backward(self, device="mps", dtype=torch.float32):
+        for layers in [1] if product_version < 13.0 else [1, 2, 5]:
+            lstm = nn.LSTM(2, 4, layers)  # initialized globally for consistent parameters init
+            lstm.train()

-    @unittest.skipIf(True, "Backward of lstm returns wrong result")
-    def test_lstm_2(self, device="mps", dtype=torch.float32):
-        def get_results(device):
-            rnn = nn.LSTM(1, 4, 1, device=device)
-            inp = torch.randn(2, 3, 1, device=device, requires_grad=True)
-            hx = torch.zeros(1, 3, 4, device=device)
-            cx = torch.zeros(1, 3, 4, device=device)
+            def get_results(device, inp, hx, cx):
+                rnn = lstm.to(device)
+                inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)

-            output, _ = rnn(inp, (hx, cx))
-            output.sum().backward()
+                output, _ = rnn(inp, (hx, cx))
+                f = output.sum()

-            weight_grad = rnn.weight_ih_l0.grad.clone()
-            input_grad = inp.grad.clone()
+                param_names, params = zip(*rnn.named_parameters())
+                param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))

-            return output, weight_grad, input_grad
+                input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
+                return output, param_grads, input_grad, hx_grad, cx_grad

+            inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device)
+            hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
+            cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)

-        cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
-        mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
+            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
+            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)

+            self.assertEqual(cpu_hx_grad, mps_hx_grad)
+            self.assertEqual(cpu_cx_grad, mps_cx_grad)
+            self.assertEqual(cpu_output, mps_output)
+            self.assertEqual(cpu_input_grad, mps_input_grad)
+            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+                self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")

+            # test batch_first backward
+            lstm = nn.LSTM(2, 4, layers, batch_first=True)
+            lstm.train()

+            hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
+            cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)

+            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
+            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)

+            self.assertEqual(cpu_hx_grad, mps_hx_grad)
+            self.assertEqual(cpu_cx_grad, mps_cx_grad)
+            self.assertEqual(cpu_output, mps_output)
+            self.assertEqual(cpu_input_grad, mps_input_grad)
+            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+                self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")

-        self.assertEqual(cpu_output, mps_output)
-        self.assertEqual(cpu_input_grad, mps_input_grad)
-        self.assertEqual(cpu_weight_grad, mps_weight_grad)

     def test_RNN_cell_no_broadcasting(self):
         def test(cell_module, input, hx, input_size, hidden_size):
@@ -2568,11 +2568,11 @@
 input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"

 #LSTM MPS
-- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
-  output_differentiability: [True, True, True, False, False]
-  input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
+- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+  output_differentiability: [True, True, True, False, False, False]
+  input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"

-- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
@@ -1109,6 +1109,7 @@ SUPPORTED_RETURN_TYPES = {
 "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
 "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
 "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
+"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
 "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
 "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
 "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",