[MPS] Fix LSTM backward and forward pass (#95137)

Fixes #91694
Fixes #92615

Several transpositions were missing from the backward graph when `batch_first=True`. #91694 does not reproduce with `batch_first=False`.
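A minimal sketch of this kind of repro (module size, shapes, and tolerance are illustrative, not taken from the issue): run the same forward/backward pass with `batch_first=True` on CPU and on MPS and compare.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
rnn = nn.LSTM(7, 4, 1, batch_first=True)  # created on CPU so both runs share weights
inp = torch.randn(3, 2, 7)                # (batch, seq, feature) because batch_first=True

def run(device):
    model = rnn.to(device)
    x = inp.detach().to(device).requires_grad_()
    out, _ = model(x)
    out.sum().backward()
    return out.detach().cpu(), x.grad.cpu()

cpu_out, cpu_grad = run("cpu")
mps_out, mps_grad = run("mps")
# Before this fix, the input gradient diverged on MPS when batch_first=True.
print(torch.allclose(cpu_out, mps_out, atol=1e-5),
      torch.allclose(cpu_grad, mps_grad, atol=1e-5))
```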

After fixing the transpose issue, I thought I could finally use LSTM freely in my project. Then I got horrific results during training, which seems related to #92615.
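A sketch of how that kind of divergence shows up (sizes and tolerance are arbitrary): compare the per-parameter gradients produced by CPU and MPS for the same module and input.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
lstm = nn.LSTM(2, 4, 1)
inp = torch.randn(5, 3, 2)  # (seq, batch, feature)

def param_grads(device):
    model = lstm.to(device)
    model.zero_grad()
    out, _ = model(inp.to(device))
    out.sum().backward()
    return {name: p.grad.cpu().clone() for name, p in model.named_parameters()}

cpu_grads = param_grads("cpu")
mps_grads = param_grads("mps")
for name, cpu_grad in cpu_grads.items():
    # Before this fix, weight/bias gradients computed on MPS did not match CPU.
    assert torch.allclose(cpu_grad, mps_grads[name], atol=1e-5), f"gradient mismatch in {name}"
```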

After that, I decided to fix LSTM's backward pass completely. I collected all my findings in this thread, and it seems I succeeded.

Funnily enough, the backward tests had been disabled entirely and were not passing:
```python
    @unittest.skipIf(True, "Backward of lstm returns wrong result")
    def test_lstm_2(self, device="mps", dtype=torch.float32):
```

UPD: the forward pass of the multi-layer version was also wrong, due to incorrect `initState, initCell` slices. Tests were passing because the states were initialized with zeros. *Accidentally* fixed this too.
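A sketch of why zero states hid this (sizes are arbitrary): with random, non-zero `hx`/`cx` the wrong per-layer state slices make the multi-layer forward pass on MPS diverge from CPU, while all-zero states cannot expose the mix-up.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
layers = 2
rnn = nn.LSTM(7, 4, layers)
inp = torch.randn(2, 3, 7)
hx = torch.randn(layers, 3, 4)  # non-zero initial hidden state per layer
cx = torch.randn(layers, 3, 4)  # non-zero initial cell state per layer

cpu_out, (cpu_hn, cpu_cn) = rnn(inp, (hx, cx))
mps_out, (mps_hn, mps_cn) = rnn.to("mps")(inp.to("mps"), (hx.to("mps"), cx.to("mps")))
# Before this fix, outputs diverged because layers read the wrong hx/cx slices.
print(torch.allclose(cpu_out, mps_out.cpu(), atol=1e-5))
```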

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95137
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/soulitzer
alexdremov 2023-02-23 17:32:42 +00:00 committed by PyTorch MergeBot
parent 86efa104f5
commit b9e95158d5
6 changed files with 200 additions and 133 deletions

View File

@@ -1423,7 +1423,7 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
}
#ifdef USE_MPS
if (_input.is_mps() && !bidirectional) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
return return_values;

View File

@@ -23,7 +23,7 @@ std::vector<long long> getTensorShape(MPSGraphTensor* mpsTensor) {
return output_dimensions;
}
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
using namespace mps;
//Projections are not currently supported, raise an error if needed
@@ -32,6 +32,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
AT_ERROR("LSTM with projections is not currently supported with MPS.");
}
TORCH_CHECK(!(!is_macos_13_or_newer() && num_layers > 1), "Multi-layer LSTM support in MPS available only on MacOS 13 onwards");
std::vector<Tensor> kernel_weights;
std::vector<Tensor> recurrent_kernel_weights;
std::vector<Tensor> biases;
@@ -56,8 +58,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
std::vector<MPSGraphTensor*> outputCellStateFwdVector_;
std::vector<MPSGraphTensor*> outputZStateVector_;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@@ -79,6 +79,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
NSMutableArray<MPSGraphTensor*> *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];
for (size_t i = 0; i < num_layers; i += 1) {
[kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
@@ -107,16 +108,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
}
MPSGraphTensor* inputTensor_ = inputTensor;
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:0
length:1
name:nil];
MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:0
length:1
name:nil];
NSArray<MPSGraphTensor*>* outputs = nil;
NSMutableArray<MPSGraphTensor*>* outputStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
@@ -129,6 +120,16 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
secondaryTensor:recurrentBiasList[i]
name:nil];
}
MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:i
length:1
name:nil];
MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:i
length:1
name:nil];
outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
recurrentWeight:recurrentKernelWeightsList[i]
inputWeight:kernelWeightsList[i]
@@ -138,17 +139,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
descriptor:opDesc
name:nil];
stateTensor_ = [mpsGraph sliceTensor:stateTensor
dimension:0
start:i
length:1
name:nil];
cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
dimension:0
start:i
length:1
name:nil];
inputTensor_ = [outputs objectAtIndex:0];
// no need to keep a final layer output copy as it is
// returned anyway and not used in backprop
if(i != num_layers - 1) {
[layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
axis:0
name:nil]];
}
if(dropout_p>0.0 && train && (i!=num_layers-1)) {
inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
rate:dropout_p
@@ -166,7 +164,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
name:nil]];
}
MPSGraphTensor* outputTensor = [outputs objectAtIndex:0];
MPSGraphTensor* outputTensor = inputTensor_;
if (batch_first) {
outputTensor = [mpsGraph transposeTensor:outputTensor
dimension:0
@@ -185,8 +183,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
MPSGraphTensor* outputCellStatesFwd = [mpsGraph concatTensors:outputCellStateFwdArray
dimension:0
name:nil];
MPSGraphTensor* layersOutputs = (num_layers > 1)
? [mpsGraph concatTensors:layersOutputsList dimension:0 name:nil]
: nil;
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd};
std::vector<MPSGraphTensor*> outputTensors = {outputTensor, outputStates, outputCellStates, outputZStates, outputCellStatesFwd, layersOutputs};
newCachedGraph->inputTensors_ = inputTensors;
newCachedGraph->outputTensors_ = outputTensors;
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
@@ -204,10 +205,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
Placeholder kernelWeight;
Placeholder recurrentKernelWeight;
Placeholder bias;
Placeholder recurrentBias;
Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias;
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
for (size_t i = 0; i < num_layers; i+=1) {
kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
@@ -236,6 +235,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
Tensor cy = at::empty_like(hx[1], input.options());
Tensor zState = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[3])), input.options());
Tensor cellStateFwd = at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[4])), input.options());
Tensor layerOutputs = (num_layers > 1)
? at::empty(IntArrayRef(getTensorShape(cachedGraph->outputTensors_[5])), input.options())
: at::empty({ 1 }, input.options()); // not used if num_layers == 1
Placeholder outputPlaceholder0 = Placeholder(cachedGraph->outputTensors_[0], output);
Placeholder outputPlaceholder1 = Placeholder(cachedGraph->outputTensors_[1], hy);
@@ -243,20 +245,25 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input
Placeholder outputPlaceholder3 = Placeholder(cachedGraph->outputTensors_[3], zState);
Placeholder outputPlaceholder4 = Placeholder(cachedGraph->outputTensors_[4], cellStateFwd);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = [@{
outputPlaceholder0.getMPSGraphTensor() : outputPlaceholder0.getMPSGraphTensorData(),
outputPlaceholder1.getMPSGraphTensor() : outputPlaceholder1.getMPSGraphTensorData(),
outputPlaceholder2.getMPSGraphTensor() : outputPlaceholder2.getMPSGraphTensorData(),
outputPlaceholder3.getMPSGraphTensor() : outputPlaceholder3.getMPSGraphTensorData(),
outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData()
};
outputPlaceholder4.getMPSGraphTensor() : outputPlaceholder4.getMPSGraphTensorData(),
} mutableCopy];
if (num_layers > 1) {
Placeholder outputPlaceholder5 = Placeholder(cachedGraph->outputTensors_[5], layerOutputs);
[results setObject:outputPlaceholder5.getMPSGraphTensorData() forKey: outputPlaceholder5.getMPSGraphTensor()];
}
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::make_tuple(output, hy, cy, zState, cellStateFwd);
return std::make_tuple(output, hy, cy, zState, cellStateFwd, layerOutputs);
}
}
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(const Tensor& grad_y, const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, const Tensor& z_state, const Tensor& cell_state_fwd, const Tensor& input, const Tensor& layersOutputs, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
using namespace mps;
const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] {return Tensor();});
const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] {return Tensor();});
@@ -287,12 +294,12 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
NSMutableArray<MPSGraphTensor*> *recurrentBiasList_ = nil;
NSMutableArray<MPSGraphTensor*> *gradOutput_ = nil;
NSMutableArray<MPSGraphTensor*> *gradRecWeights_ = nil;
NSMutableArray<MPSGraphTensor*> *gradWeights_ = nil;
NSMutableArray<MPSGraphTensor*> *gradBias_ = nil;
NSMutableArray<MPSGraphTensor*> *gradState_ = nil;
NSMutableArray<MPSGraphTensor*> *gradCellState_ = nil;
MPSGraphTensor* gradOutput_ = nil;
MPSGraphTensor* gradState_ = nil;
MPSGraphTensor* gradCellState_ = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@@ -333,8 +340,22 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
MPSGraphTensor* gradientCyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_cy.scalar_type()), getMPSShape(grad_cy));
MPSGraphTensor* gradientHyTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_hy.scalar_type()), getMPSShape(grad_hy));
MPSGraphTensor* cellStateFwdTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(cell_state_fwd.scalar_type()), getMPSShape(cell_state_fwd));
MPSGraphTensor* layersOutputsTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(layersOutputs.scalar_type()), getMPSShape(layersOutputs));
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor, layersOutputsTensor};
if (batch_first) {
inputTensor = [mpsGraph transposeTensor: inputTensor
dimension: 0
withDimension: 1
name: nil];
gradientTensor = [mpsGraph transposeTensor: gradientTensor
dimension: 0
withDimension: 1
name: nil];
}
std::vector<MPSGraphTensor*> inputs = {inputTensor, stateTensor, cellStateTensor, gradientTensor, zStateTensor, cellStateFwdTensor, gradientHyTensor, gradientCyTensor};
newCachedGraph->recurrentKernelWeightsList_ = recurrentKernelWeightsList;
newCachedGraph->kernelWeightsList_ = kernelWeightsList;
newCachedGraph->biasList_ = kernelBiasList;
@@ -350,7 +371,6 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
NSArray<MPSGraphTensor*>* outputs = nil;
NSMutableArray<MPSGraphTensor*>* gradOutputArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradRecWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradWeightsArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
NSMutableArray<MPSGraphTensor*>* gradBiasArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
@@ -406,7 +426,23 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
length:1
name:nil];
outputs = [mpsGraph LSTMGradientsWithSourceTensor: inputTensor
MPSGraphTensor* iterationInputTensor_ = nil;
if (i == 0) {
iterationInputTensor_ = inputTensor;
} else {
iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor
dimension: 0
// last element in layersOutputsTensor contains
// **inputs** for the last layer
start: i - num_layers
length: 1
name: nil];
iterationInputTensor_ = [mpsGraph squeezeTensor:iterationInputTensor_
axis:0
name: nil];
}
outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_
recurrentWeight: recurrentKernelWeightsList[i]
sourceGradient: gradientTensor_
zState: zState
@@ -423,22 +459,30 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
name: nil];
gradientTensor_ = [outputs objectAtIndex:0];
[gradOutputArray addObject:[outputs objectAtIndex:0]];
[gradRecWeightsArray addObject:[outputs objectAtIndex:1]];
[gradWeightsArray addObject:[outputs objectAtIndex:2]];
[gradBiasArray addObject:[outputs objectAtIndex:3]];
[gradStateArray addObject:[outputs objectAtIndex:4]];
[gradCellStateArray addObject:[outputs objectAtIndex:5]];
[gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0];
[gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0];
[gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0];
[gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil] atIndex:0];
[gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0];
}
std::vector<MPSGraphTensor*> outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]};
if (batch_first) {
MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_
dimension: 0
withDimension: 1
name:nil];
newCachedGraph->gradOutput_ = gradientTensorTransposed;
} else {
newCachedGraph->gradOutput_ = gradientTensor_;
}
newCachedGraph->outputTensors_ = outputTensors;
newCachedGraph->gradOutput_ = gradOutputArray;
newCachedGraph->gradRecWeights_ = gradRecWeightsArray;
newCachedGraph->gradWeights_ = gradWeightsArray;
newCachedGraph->gradBias_ = gradBiasArray;
newCachedGraph->gradState_ = gradStateArray;
newCachedGraph->gradCellState_ = gradCellStateArray;
newCachedGraph->gradState_ = [mpsGraph concatTensors:gradStateArray dimension: 0 name: nil];
newCachedGraph->gradCellState_ = [mpsGraph concatTensors:gradCellStateArray dimension: 0 name: nil];
}
return newCachedGraph;
});
@@ -453,6 +497,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
Placeholder cellStateFwdPlaceholder = Placeholder(cachedGraph->inputTensors_[5], cell_state_fwd);
Placeholder gradientHyPlaceholder = Placeholder(cachedGraph->inputTensors_[6], grad_hy);
Placeholder gradientCyPlaceholder = Placeholder(cachedGraph->inputTensors_[7], grad_cy);
Placeholder layersOutputsPlaceholder = Placeholder(cachedGraph->inputTensors_[8], layersOutputs);
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
[feeds setObject:gradientPlaceholder.getMPSGraphTensorData() forKey:gradientPlaceholder.getMPSGraphTensor()];
@@ -463,6 +508,7 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
[feeds setObject:cellStatePlaceholder.getMPSGraphTensorData() forKey:cellStatePlaceholder.getMPSGraphTensor()];
[feeds setObject:zStatePlaceholder.getMPSGraphTensorData() forKey:zStatePlaceholder.getMPSGraphTensor()];
[feeds setObject:cellStateFwdPlaceholder.getMPSGraphTensorData() forKey:cellStateFwdPlaceholder.getMPSGraphTensor()];
[feeds setObject:layersOutputsPlaceholder.getMPSGraphTensorData() forKey:layersOutputsPlaceholder.getMPSGraphTensor()];
NSMutableArray<MPSGraphTensor*> *kernelWeightsList = cachedGraph->kernelWeightsList_;
NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
@@ -485,62 +531,55 @@ std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> lstm_mps_backward(c
}
}
Tensor output = at::empty_like(input);
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[0]);
Tensor grad_weights = at::empty_like(kernel_weights[0]);
Tensor grad_bias = at::empty((kernel_weights[0].size(0)), kernel_weights[0].options());
Tensor grad_state = at::empty_like(hx[0]);
Tensor grad_cell_state = at::empty_like(hx[1]);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensors_[0], output);
Placeholder gradRecWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[1], grad_rec_weights);
Placeholder gradWeightsPlaceholder = Placeholder(cachedGraph->outputTensors_[2], grad_weights);
Placeholder gradBiasPlaceholder = Placeholder(cachedGraph->outputTensors_[3], grad_bias);
Placeholder gradStatePlaceholder = Placeholder(cachedGraph->outputTensors_[4], grad_state);
Placeholder gradCellStatePlaceholder = Placeholder(cachedGraph->outputTensors_[5], grad_cell_state);
Tensor output_out = at::empty_like(input);
Tensor grad_state_out = at::empty_like(hx[0]);
Tensor grad_cell_state_out = at::empty_like(hx[1]);
std::vector<Tensor> grad_hx = {grad_state, grad_cell_state};
std::vector<Tensor> grad_hx = {grad_state_out, grad_cell_state_out};
NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *results = [[[NSMutableDictionary alloc] init] autorelease];
NSMutableArray<MPSGraphTensor*> *gradOutputArray = cachedGraph->gradOutput_;
NSMutableArray<MPSGraphTensor*> *gradRecWeightsArray = cachedGraph->gradRecWeights_;
NSMutableArray<MPSGraphTensor*> *gradWeightsArray = cachedGraph->gradWeights_;
NSMutableArray<MPSGraphTensor*> *gradBiasArray = cachedGraph->gradBias_;
NSMutableArray<MPSGraphTensor*> *gradStateArray = cachedGraph->gradState_;
NSMutableArray<MPSGraphTensor*> *gradCellStateArray = cachedGraph->gradCellState_;
Placeholder gradOutPlaceholder;
MPSGraphTensor* gradOutput = cachedGraph->gradOutput_;
MPSGraphTensor* gradState = cachedGraph->gradState_;
MPSGraphTensor* gradCellState = cachedGraph->gradCellState_;
Placeholder gradStatePlaceholder = Placeholder(gradState, grad_state_out);
Placeholder gradCellStatePlaceholder = Placeholder(gradCellState, grad_cell_state_out);
Placeholder outputPlaceholder = Placeholder(gradOutput, output_out);
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
[results setObject:outputPlaceholder.getMPSGraphTensorData() forKey:outputPlaceholder.getMPSGraphTensor()];
Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;
std::vector<Tensor> weights;
for (int i = 0; i < num_layers; i++) {
Tensor output = at::empty_like(input);
Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]);
Tensor grad_weights = at::empty_like(kernel_weights[i]);
Tensor grad_bias = at::empty((kernel_weights[0].size(0)), kernel_weights[0].options());
Tensor grad_state = at::empty_like(hx[0]);
Tensor grad_cell_state = at::empty_like(hx[1]);
Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
weights.push_back(grad_weights);
weights.push_back(grad_rec_weights);
if(has_biases) {
weights.push_back(grad_bias);
weights.push_back(grad_bias);
}
gradOutPlaceholder = Placeholder([gradOutputArray objectAtIndex:i], output);
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex:i], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex:i], grad_weights);
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex:i], grad_bias);
gradStatePlaceholder = Placeholder([gradStateArray objectAtIndex:i], grad_state);
gradCellStatePlaceholder = Placeholder([gradCellStateArray objectAtIndex:i], grad_cell_state);
[results setObject:gradOutPlaceholder.getMPSGraphTensorData() forKey:gradOutPlaceholder.getMPSGraphTensor()];
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights);
gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights);
gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
[results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
[results setObject:gradStatePlaceholder.getMPSGraphTensorData() forKey:gradStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradCellStatePlaceholder.getMPSGraphTensorData() forKey:gradCellStatePlaceholder.getMPSGraphTensor()];
[results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
[results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
}
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output, grad_hx, weights);
return std::tuple<Tensor, std::vector<Tensor>, std::vector<Tensor>> (output_out, grad_hx, weights);
}
}

View File

@@ -7200,12 +7200,12 @@
# MPS LSTM implementation
- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
dispatch:
MPS: _lstm_mps
autogen: _lstm_mps.out
- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
- func: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
dispatch:
MPS: lstm_mps_backward
autogen: lstm_mps_backward.out

View File

@@ -9008,64 +9008,91 @@ class TestAdvancedIndexing(TestCaseMPS):
class TestRNNMPS(TestCaseMPS):
def test_lstm_1(self, device="mps", dtype=torch.float32):
for layers in [1] if product_version < 13.0 else [1, 2, 5]:
torch.random.manual_seed(42)
rnn = nn.LSTM(7, 4, layers, device="cpu")
input = torch.randn(2, 3, 7, device="cpu")
hx = torch.randn(layers, 3, 4, device="cpu")
cx = torch.randn(layers, 3, 4, device="cpu")
rnn = nn.LSTM(1, 4, 2, device="cpu")
input = torch.randn(2, 3, 1, device="cpu")
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
# test batch_first
rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True)
input = torch.randn(3, 2, 7, device="cpu")
hx = torch.randn(layers, 3, 4, device="cpu")
cx = torch.randn(layers, 3, 4, device="cpu")
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
# test batch_first
rnn = nn.LSTM(1, 4, 2, device="cpu", batch_first=True)
input = torch.randn(3, 2, 1, device="cpu")
hx = torch.zeros(2, 3, 4, device="cpu")
cx = torch.zeros(2, 3, 4, device="cpu")
cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
rnn = rnn.to(device)
input = input.to(device)
hx = hx.to(device)
cx = cx.to(device)
output, (hn, cn) = rnn(input, (hx, cx))
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
self.assertEqual(cpu_output, output)
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
def test_lstm_backward(self, device="mps", dtype=torch.float32):
for layers in [1] if product_version < 13.0 else [1, 2, 5]:
lstm = nn.LSTM(2, 4, layers) # initialized globally for consistent parameters init
lstm.train()
@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
def get_results(device):
rnn = nn.LSTM(1, 4, 1, device=device)
inp = torch.randn(2, 3, 1, device=device, requires_grad=True)
hx = torch.zeros(1, 3, 4, device=device)
cx = torch.zeros(1, 3, 4, device=device)
def get_results(device, inp, hx, cx):
rnn = lstm.to(device)
inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
output, _ = rnn(inp, (hx, cx))
output.sum().backward()
output, _ = rnn(inp, (hx, cx))
f = output.sum()
weight_grad = rnn.weight_ih_l0.grad.clone()
input_grad = inp.grad.clone()
param_names, params = zip(*rnn.named_parameters())
param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
return output, weight_grad, input_grad
input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
return output, param_grads, input_grad, hx_grad, cx_grad
inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device)
hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
cpu_output, cpu_weight_grad, cpu_input_grad = get_results("cpu")
mps_output, mps_weight_grad, mps_input_grad = get_results("mps")
cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
self.assertEqual(cpu_hx_grad, mps_hx_grad)
self.assertEqual(cpu_cx_grad, mps_cx_grad)
self.assertEqual(cpu_output, mps_output)
self.assertEqual(cpu_input_grad, mps_input_grad)
for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
# test batch_first backward
lstm = nn.LSTM(2, 4, layers, batch_first=True)
lstm.train()
hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
self.assertEqual(cpu_hx_grad, mps_hx_grad)
self.assertEqual(cpu_cx_grad, mps_cx_grad)
self.assertEqual(cpu_output, mps_output)
self.assertEqual(cpu_input_grad, mps_input_grad)
for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
self.assertEqual(cpu_output, mps_output)
self.assertEqual(cpu_input_grad, mps_input_grad)
self.assertEqual(cpu_weight_grad, mps_weight_grad)
def test_RNN_cell_no_broadcasting(self):
def test(cell_module, input, hx, input_size, hidden_size):

View File

@@ -2568,11 +2568,11 @@
input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector<int64_t>(padding.size(), 1), false, std::vector<c10::SymInt>(padding.size(), 0), 1, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>()"
#LSTM MPS
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
output_differentiability: [True, True, True, False, False]
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
output_differentiability: [True, True, True, False, False, False]
input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
- name: lstm_mps_backward(Tensor grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])

View File

@@ -1109,6 +1109,7 @@ SUPPORTED_RETURN_TYPES = {
"::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
"::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",