mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Update on "[Pytorch][Bootcamp] Add fix and testing for vectorized Adadelta optimizer to handle complex numbers"
Made changes in the step function of the vectorized Adadelta optimizer to handle complex numbers as two real numbers as per 65711 on github. Differential Revision: [D31631870](https://our.internmc.facebook.com/intern/diff/D31631870/) [ghstack-poisoned]
This commit is contained in:
commit
f2aa06c17a
|
|
@ -10,7 +10,6 @@ CONFIG_TREE_DATA = [
|
||||||
]),
|
]),
|
||||||
]),
|
]),
|
||||||
# TODO: bring back libtorch test
|
# TODO: bring back libtorch test
|
||||||
("7", [X("3.6")]),
|
|
||||||
]),
|
]),
|
||||||
("cuda", [
|
("cuda", [
|
||||||
("10.2", [
|
("10.2", [
|
||||||
|
|
|
||||||
44
.circleci/config.yml
generated
44
.circleci/config.yml
generated
|
|
@ -6582,31 +6582,6 @@ workflows:
|
||||||
name: pytorch_cpp_doc_push
|
name: pytorch_cpp_doc_push
|
||||||
requires:
|
requires:
|
||||||
- pytorch_cpp_doc_build
|
- pytorch_cpp_doc_build
|
||||||
- pytorch_linux_build:
|
|
||||||
name: pytorch_linux_xenial_py3_6_gcc7_build
|
|
||||||
requires:
|
|
||||||
- "docker-pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
filters:
|
|
||||||
branches:
|
|
||||||
only:
|
|
||||||
- master
|
|
||||||
- /ci-all\/.*/
|
|
||||||
- /release\/.*/
|
|
||||||
build_environment: "pytorch-linux-xenial-py3.6-gcc7-build"
|
|
||||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
- pytorch_linux_test:
|
|
||||||
name: pytorch_linux_xenial_py3_6_gcc7_test
|
|
||||||
requires:
|
|
||||||
- pytorch_linux_xenial_py3_6_gcc7_build
|
|
||||||
filters:
|
|
||||||
branches:
|
|
||||||
only:
|
|
||||||
- master
|
|
||||||
- /ci-all\/.*/
|
|
||||||
- /release\/.*/
|
|
||||||
build_environment: "pytorch-linux-xenial-py3.6-gcc7-test"
|
|
||||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
resource_class: large
|
|
||||||
- pytorch_linux_build:
|
- pytorch_linux_build:
|
||||||
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
|
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
|
||||||
requires:
|
requires:
|
||||||
|
|
@ -8334,9 +8309,6 @@ workflows:
|
||||||
only: /.*/
|
only: /.*/
|
||||||
tags:
|
tags:
|
||||||
only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
|
only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
|
||||||
- docker_build_job:
|
|
||||||
name: "docker-pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
image_name: "pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
when: << pipeline.parameters.run_build >>
|
when: << pipeline.parameters.run_build >>
|
||||||
master_build:
|
master_build:
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -8352,19 +8324,6 @@ workflows:
|
||||||
- pytorch_cpp_doc_build:
|
- pytorch_cpp_doc_build:
|
||||||
requires:
|
requires:
|
||||||
- pytorch_linux_xenial_py3_6_gcc5_4_build
|
- pytorch_linux_xenial_py3_6_gcc5_4_build
|
||||||
- pytorch_linux_build:
|
|
||||||
name: pytorch_linux_xenial_py3_6_gcc7_build
|
|
||||||
requires:
|
|
||||||
- "docker-pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
build_environment: "pytorch-linux-xenial-py3.6-gcc7-build"
|
|
||||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
- pytorch_linux_test:
|
|
||||||
name: pytorch_linux_xenial_py3_6_gcc7_test
|
|
||||||
requires:
|
|
||||||
- pytorch_linux_xenial_py3_6_gcc7_build
|
|
||||||
build_environment: "pytorch-linux-xenial-py3.6-gcc7-test"
|
|
||||||
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
resource_class: large
|
|
||||||
- pytorch_linux_build:
|
- pytorch_linux_build:
|
||||||
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
|
name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
|
||||||
requires:
|
requires:
|
||||||
|
|
@ -8474,9 +8433,6 @@ workflows:
|
||||||
- docker_build_job:
|
- docker_build_job:
|
||||||
name: "docker-pytorch-linux-xenial-py3.6-gcc5.4"
|
name: "docker-pytorch-linux-xenial-py3.6-gcc5.4"
|
||||||
image_name: "pytorch-linux-xenial-py3.6-gcc5.4"
|
image_name: "pytorch-linux-xenial-py3.6-gcc5.4"
|
||||||
- docker_build_job:
|
|
||||||
name: "docker-pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
image_name: "pytorch-linux-xenial-py3.6-gcc7"
|
|
||||||
when: << pipeline.parameters.run_master_build >>
|
when: << pipeline.parameters.run_master_build >>
|
||||||
ecr_gc:
|
ecr_gc:
|
||||||
triggers:
|
triggers:
|
||||||
|
|
|
||||||
4
.github/generated-ciflow-ruleset.json
generated
vendored
4
.github/generated-ciflow-ruleset.json
generated
vendored
|
|
@ -17,6 +17,7 @@
|
||||||
"linux-xenial-py3.6-clang7-asan",
|
"linux-xenial-py3.6-clang7-asan",
|
||||||
"linux-xenial-py3.6-clang7-onnx",
|
"linux-xenial-py3.6-clang7-onnx",
|
||||||
"linux-xenial-py3.6-gcc5.4",
|
"linux-xenial-py3.6-gcc5.4",
|
||||||
|
"linux-xenial-py3.6-gcc7",
|
||||||
"linux-xenial-py3.6-gcc7-bazel-test",
|
"linux-xenial-py3.6-gcc7-bazel-test",
|
||||||
"parallelnative-linux-xenial-py3.6-gcc5.4",
|
"parallelnative-linux-xenial-py3.6-gcc5.4",
|
||||||
"periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7",
|
"periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7",
|
||||||
|
|
@ -36,6 +37,7 @@
|
||||||
"linux-xenial-py3.6-clang7-asan",
|
"linux-xenial-py3.6-clang7-asan",
|
||||||
"linux-xenial-py3.6-clang7-onnx",
|
"linux-xenial-py3.6-clang7-onnx",
|
||||||
"linux-xenial-py3.6-gcc5.4",
|
"linux-xenial-py3.6-gcc5.4",
|
||||||
|
"linux-xenial-py3.6-gcc7",
|
||||||
"linux-xenial-py3.6-gcc7-bazel-test",
|
"linux-xenial-py3.6-gcc7-bazel-test",
|
||||||
"parallelnative-linux-xenial-py3.6-gcc5.4",
|
"parallelnative-linux-xenial-py3.6-gcc5.4",
|
||||||
"win-vs2019-cpu-py3"
|
"win-vs2019-cpu-py3"
|
||||||
|
|
@ -62,6 +64,7 @@
|
||||||
"linux-xenial-py3.6-clang7-asan",
|
"linux-xenial-py3.6-clang7-asan",
|
||||||
"linux-xenial-py3.6-clang7-onnx",
|
"linux-xenial-py3.6-clang7-onnx",
|
||||||
"linux-xenial-py3.6-gcc5.4",
|
"linux-xenial-py3.6-gcc5.4",
|
||||||
|
"linux-xenial-py3.6-gcc7",
|
||||||
"linux-xenial-py3.6-gcc7-bazel-test",
|
"linux-xenial-py3.6-gcc7-bazel-test",
|
||||||
"win-vs2019-cpu-py3",
|
"win-vs2019-cpu-py3",
|
||||||
"win-vs2019-cuda11.3-py3"
|
"win-vs2019-cuda11.3-py3"
|
||||||
|
|
@ -87,6 +90,7 @@
|
||||||
"linux-xenial-py3.6-clang7-asan",
|
"linux-xenial-py3.6-clang7-asan",
|
||||||
"linux-xenial-py3.6-clang7-onnx",
|
"linux-xenial-py3.6-clang7-onnx",
|
||||||
"linux-xenial-py3.6-gcc5.4",
|
"linux-xenial-py3.6-gcc5.4",
|
||||||
|
"linux-xenial-py3.6-gcc7",
|
||||||
"linux-xenial-py3.6-gcc7-bazel-test",
|
"linux-xenial-py3.6-gcc7-bazel-test",
|
||||||
"parallelnative-linux-xenial-py3.6-gcc5.4",
|
"parallelnative-linux-xenial-py3.6-gcc5.4",
|
||||||
"periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7",
|
"periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7",
|
||||||
|
|
|
||||||
11
.github/scripts/generate_ci_workflows.py
vendored
11
.github/scripts/generate_ci_workflows.py
vendored
|
|
@ -274,6 +274,17 @@ LINUX_WORKFLOWS = [
|
||||||
labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}
|
labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
CIWorkflow(
|
||||||
|
arch="linux",
|
||||||
|
build_environment="linux-xenial-py3.6-gcc7",
|
||||||
|
docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc7",
|
||||||
|
test_runner_type=LINUX_CPU_TEST_RUNNER,
|
||||||
|
num_test_shards=2,
|
||||||
|
ciflow_config=CIFlowConfig(
|
||||||
|
run_on_canary=True,
|
||||||
|
labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}
|
||||||
|
),
|
||||||
|
),
|
||||||
# ParallelTBB does not have a maintainer and is currently flaky
|
# ParallelTBB does not have a maintainer and is currently flaky
|
||||||
# CIWorkflow(
|
# CIWorkflow(
|
||||||
# arch="linux",
|
# arch="linux",
|
||||||
|
|
|
||||||
517
.github/workflows/generated-linux-xenial-py3.6-gcc7.yml
generated
vendored
Normal file
517
.github/workflows/generated-linux-xenial-py3.6-gcc7.yml
generated
vendored
Normal file
|
|
@ -0,0 +1,517 @@
|
||||||
|
# @generated DO NOT EDIT MANUALLY
|
||||||
|
# Template is at: .github/templates/linux_ci_workflow.yml.j2
|
||||||
|
# Generation script: .github/scripts/generate_ci_workflows.py
|
||||||
|
name: linux-xenial-py3.6-gcc7
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, synchronize, reopened, unassigned]
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
- release/*
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
env:
|
||||||
|
BUILD_ENVIRONMENT: linux-xenial-py3.6-gcc7
|
||||||
|
DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7
|
||||||
|
SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
|
||||||
|
XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
|
||||||
|
TORCH_CUDA_ARCH_LIST: 5.2
|
||||||
|
IN_CI: 1
|
||||||
|
IS_GHA: 1
|
||||||
|
# This is used for the phase of adding wheel tests only, will be removed once completed
|
||||||
|
IN_WHEEL_TEST: 1
|
||||||
|
# Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh
|
||||||
|
CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts
|
||||||
|
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||||
|
PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }}
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
AWS_DEFAULT_REGION: us-east-1
|
||||||
|
CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
|
CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||||
|
concurrency:
|
||||||
|
group: linux-xenial-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
|
||||||
|
ciflow_should_run:
|
||||||
|
runs-on: ubuntu-18.04
|
||||||
|
env:
|
||||||
|
IS_PROBOT_TRIGGER_EVENT: ${{ (github.event.action == 'unassigned') && (github.event.assigneed.login == 'pytorchbot') }}
|
||||||
|
LABEL_CONDITIONS: ${{ contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') }}
|
||||||
|
LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }}
|
||||||
|
if: ${{ (github.repository_owner == 'pytorch') && (
|
||||||
|
(github.event_name == 'push') ||
|
||||||
|
(github.event_name == 'schedule') ||
|
||||||
|
(contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) ||
|
||||||
|
((github.event_name == 'pull_request' && github.event.action != 'unassigned') && !contains(join(github.event.pull_request.labels.*.name), 'ciflow/')))
|
||||||
|
}}
|
||||||
|
steps:
|
||||||
|
- name: noop
|
||||||
|
run: echo running ciflow_should_run
|
||||||
|
- name: print labels
|
||||||
|
run: echo "${LABELS}"
|
||||||
|
|
||||||
|
build:
|
||||||
|
runs-on: linux.2xlarge
|
||||||
|
needs: [ciflow_should_run]
|
||||||
|
env:
|
||||||
|
JOB_BASE_NAME: linux-xenial-py3.6-gcc7-build
|
||||||
|
outputs:
|
||||||
|
docker_image: ${{ steps.calculate-tag.outputs.docker_image }}
|
||||||
|
steps:
|
||||||
|
- name: Display EC2 information
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
function get_ec2_metadata() {
|
||||||
|
# Pulled from instance metadata endpoint for EC2
|
||||||
|
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||||
|
category=$1
|
||||||
|
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||||
|
}
|
||||||
|
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||||
|
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||||
|
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||||
|
- name: Log in to ECR
|
||||||
|
env:
|
||||||
|
AWS_RETRY_MODE: standard
|
||||||
|
AWS_MAX_ATTEMPTS: 5
|
||||||
|
run: |
|
||||||
|
aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh
|
||||||
|
bash /tmp/ecr-login.sh
|
||||||
|
rm /tmp/ecr-login.sh
|
||||||
|
- name: Chown workspace
|
||||||
|
env:
|
||||||
|
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||||
|
run: |
|
||||||
|
retry () {
|
||||||
|
"$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
|
||||||
|
}
|
||||||
|
retry docker pull "${ALPINE_IMAGE}"
|
||||||
|
# Ensure the working directory gets chowned back to the current user
|
||||||
|
docker run --pull=never --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
|
||||||
|
- name: Clean workspace
|
||||||
|
run: |
|
||||||
|
rm -rf "${GITHUB_WORKSPACE:?}/*"
|
||||||
|
rm -f ~/.ssh/authorized_keys
|
||||||
|
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||||
|
uses: seemethere/add-github-ssh-key@v1
|
||||||
|
with:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Preserve github env variables for use in docker
|
||||||
|
run: |
|
||||||
|
env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}"
|
||||||
|
- name: Checkout PyTorch
|
||||||
|
uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9
|
||||||
|
with:
|
||||||
|
# deep clone, to allow use of git merge-base
|
||||||
|
fetch-depth: 0
|
||||||
|
submodules: recursive
|
||||||
|
- name: Calculate docker image tag
|
||||||
|
id: calculate-tag
|
||||||
|
run: |
|
||||||
|
DOCKER_TAG=$(git rev-parse HEAD:.circleci/docker)
|
||||||
|
echo "DOCKER_TAG=${DOCKER_TAG}" >> "${GITHUB_ENV}"
|
||||||
|
echo "DOCKER_IMAGE=${DOCKER_IMAGE_BASE}:${DOCKER_TAG}" >> "${GITHUB_ENV}"
|
||||||
|
echo "::set-output name=docker_tag::${DOCKER_TAG}"
|
||||||
|
echo "::set-output name=docker_image::${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"
|
||||||
|
- name: Check if image should be built
|
||||||
|
id: check
|
||||||
|
env:
|
||||||
|
BASE_REVISION: ${{ github.event.pull_request.base.sha || github.sha }}
|
||||||
|
run: |
|
||||||
|
set -x
|
||||||
|
# Check if image already exists, if it does then skip building it
|
||||||
|
if docker manifest inspect "${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"; then
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
if [[ "$BASE_REVISION" = "$(git rev-parse HEAD)" ]]; then
|
||||||
|
# if we're on the base branch then use the parent commit
|
||||||
|
MERGE_BASE=$(git rev-parse HEAD~)
|
||||||
|
else
|
||||||
|
# otherwise we're on a PR, so use the most recent base commit
|
||||||
|
MERGE_BASE=$(git merge-base HEAD "$BASE_REVISION")
|
||||||
|
fi
|
||||||
|
# Covers the case where a previous tag doesn't exist for the tree
|
||||||
|
# this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly
|
||||||
|
if ! git rev-parse "$MERGE_BASE:.circleci/docker"; then
|
||||||
|
echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
PREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker")
|
||||||
|
# If no image exists but the hash is the same as the previous hash then we should error out here
|
||||||
|
if [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then
|
||||||
|
echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch"
|
||||||
|
echo " contact the PyTorch team to restore the original images"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo ::set-output name=rebuild::yes
|
||||||
|
- name: Build and push docker image
|
||||||
|
if: ${{ steps.check.outputs.rebuild }}
|
||||||
|
env:
|
||||||
|
DOCKER_SKIP_S3_UPLOAD: 1
|
||||||
|
working-directory: .circleci/docker
|
||||||
|
run: |
|
||||||
|
export IMAGE_NAME=${DOCKER_IMAGE_BASE#308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/}
|
||||||
|
./build_docker.sh
|
||||||
|
- name: Pull Docker image
|
||||||
|
run: |
|
||||||
|
retry () {
|
||||||
|
"$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
|
||||||
|
}
|
||||||
|
retry docker pull "${DOCKER_IMAGE}"
|
||||||
|
- name: Parse ref
|
||||||
|
id: parse-ref
|
||||||
|
run: .github/scripts/parse_ref.py
|
||||||
|
- name: Build
|
||||||
|
env:
|
||||||
|
CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||||
|
run: |
|
||||||
|
# detached container should get cleaned up by teardown_ec2_linux
|
||||||
|
container_name=$(docker run \
|
||||||
|
-e BUILD_ENVIRONMENT \
|
||||||
|
-e JOB_BASE_NAME \
|
||||||
|
-e MAX_JOBS="$(nproc --ignore=2)" \
|
||||||
|
-e AWS_DEFAULT_REGION \
|
||||||
|
-e IS_GHA \
|
||||||
|
-e CIRCLE_PR_NUMBER \
|
||||||
|
-e CIRCLE_SHA1 \
|
||||||
|
-e CIRCLE_BRANCH \
|
||||||
|
-e GITHUB_RUN_ID \
|
||||||
|
-e SCCACHE_BUCKET \
|
||||||
|
-e XLA_CLANG_CACHE_S3_BUCKET_NAME \
|
||||||
|
-e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
|
||||||
|
-e SKIP_SCCACHE_INITIALIZATION=1 \
|
||||||
|
-e TORCH_CUDA_ARCH_LIST \
|
||||||
|
-e PR_LABELS \
|
||||||
|
-e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,github.com,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \
|
||||||
|
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
|
||||||
|
--security-opt seccomp=unconfined \
|
||||||
|
--cap-add=SYS_PTRACE \
|
||||||
|
--tty \
|
||||||
|
--detach \
|
||||||
|
--user jenkins \
|
||||||
|
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
|
||||||
|
-w /var/lib/jenkins/workspace \
|
||||||
|
"${DOCKER_IMAGE}"
|
||||||
|
)
|
||||||
|
docker exec -t "${container_name}" sh -c 'sudo chown -R jenkins . && .jenkins/pytorch/build.sh'
|
||||||
|
- name: Display and upload binary build size statistics (Click Me)
|
||||||
|
# temporary hack: set CIRCLE_* vars, until we update
|
||||||
|
# tools/stats/print_test_stats.py to natively support GitHub Actions
|
||||||
|
env:
|
||||||
|
SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }}
|
||||||
|
CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||||
|
CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
|
||||||
|
CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
|
||||||
|
run: |
|
||||||
|
COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
|
||||||
|
export COMMIT_TIME
|
||||||
|
pip3 install requests==2.26 boto3==1.16.34
|
||||||
|
python3 -m tools.stats.upload_binary_size_to_scuba || exit 0
|
||||||
|
- name: Chown workspace
|
||||||
|
run: |
|
||||||
|
# Ensure the working directory gets chowned back to the current user
|
||||||
|
docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
|
||||||
|
- name: Archive artifacts into zip
|
||||||
|
run: |
|
||||||
|
zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
|
||||||
|
- uses: seemethere/upload-artifact-s3@v3
|
||||||
|
name: Store PyTorch Build Artifacts on S3
|
||||||
|
with:
|
||||||
|
name: ${{ env.BUILD_ENVIRONMENT }}
|
||||||
|
retention-days: 14
|
||||||
|
if-no-files-found: error
|
||||||
|
path:
|
||||||
|
artifacts.zip
|
||||||
|
- name: Hold runner for 2 hours or until ssh sessions have drained
|
||||||
|
# Always hold for active ssh sessions
|
||||||
|
if: always()
|
||||||
|
run: .github/scripts/wait_for_ssh_to_drain.sh
|
||||||
|
- name: Chown workspace
|
||||||
|
if: always()
|
||||||
|
env:
|
||||||
|
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||||
|
run: |
|
||||||
|
# Ensure the working directory gets chowned back to the current user
|
||||||
|
docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
|
||||||
|
- name: Kill containers, clean up images
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
# ignore expansion of "docker ps -q" since it could be empty
|
||||||
|
# shellcheck disable=SC2046
|
||||||
|
docker stop $(docker ps -q) || true
|
||||||
|
# Prune all of the docker images
|
||||||
|
docker system prune -af
|
||||||
|
- name: Hold runner for 2 hours or until ssh sessions have drained
|
||||||
|
# Always hold for active ssh sessions
|
||||||
|
if: always()
|
||||||
|
run: .github/scripts/wait_for_ssh_to_drain.sh
|
||||||
|
- name: Clean up docker images
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
# Prune all of the docker images
|
||||||
|
docker system prune -af
|
||||||
|
|
||||||
|
generate-test-matrix:
|
||||||
|
runs-on: ubuntu-18.04
|
||||||
|
needs: [ciflow_should_run]
|
||||||
|
env:
|
||||||
|
TEST_RUNNER_TYPE: linux.2xlarge
|
||||||
|
ENABLE_DISTRIBUTED_TEST: 1
|
||||||
|
ENABLE_JIT_LEGACY_TEST: ''
|
||||||
|
ENABLE_MULTIGPU_TEST: ''
|
||||||
|
ENABLE_NOGPU_NO_AVX_TEST: ''
|
||||||
|
ENABLE_NOGPU_NO_AVX2_TEST: ''
|
||||||
|
ENABLE_SLOW_TEST: ''
|
||||||
|
ENABLE_DOCS_TEST: ''
|
||||||
|
ENABLE_BACKWARDS_COMPAT_TEST: ''
|
||||||
|
ENABLE_XLA_TEST: ''
|
||||||
|
ENABLE_NOARCH_TEST: ''
|
||||||
|
NUM_TEST_SHARDS: 2
|
||||||
|
MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu
|
||||||
|
NOGPU_RUNNER_TYPE: linux.2xlarge
|
||||||
|
PR_BODY: ${{ github.event.pull_request.body }}
|
||||||
|
outputs:
|
||||||
|
matrix: ${{ steps.set-matrix.outputs.matrix }}
|
||||||
|
render-matrix: ${{ steps.set-matrix.outputs.render-matrix }}
|
||||||
|
ignore-disabled-issues: ${{ steps.set-matrix.outputs.ignore-disabled-issues }}
|
||||||
|
container:
|
||||||
|
image: python:3.9
|
||||||
|
steps:
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install typing-extensions==3.10
|
||||||
|
- name: Clone pytorch/pytorch
|
||||||
|
uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9
|
||||||
|
- name: Generating test matrix
|
||||||
|
id: set-matrix
|
||||||
|
run: .github/scripts/generate_pytorch_test_matrix.py
|
||||||
|
|
||||||
|
test:
|
||||||
|
needs: [build, generate-test-matrix, ciflow_should_run]
|
||||||
|
strategy:
|
||||||
|
matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }}
|
||||||
|
fail-fast: false
|
||||||
|
runs-on: ${{ matrix.runner }}
|
||||||
|
env:
|
||||||
|
DOCKER_IMAGE: ${{ needs.build.outputs.docker_image }}
|
||||||
|
JOB_BASE_NAME: linux-xenial-py3.6-gcc7-test
|
||||||
|
TEST_CONFIG: ${{ matrix.config }}
|
||||||
|
SHARD_NUMBER: ${{ matrix.shard }}
|
||||||
|
NUM_TEST_SHARDS: ${{ matrix.num_shards }}
|
||||||
|
PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }}
|
||||||
|
steps:
|
||||||
|
- name: Display EC2 information
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
set -euo pipefail
|
||||||
|
function get_ec2_metadata() {
|
||||||
|
# Pulled from instance metadata endpoint for EC2
|
||||||
|
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
|
||||||
|
category=$1
|
||||||
|
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
|
||||||
|
}
|
||||||
|
echo "ami-id: $(get_ec2_metadata ami-id)"
|
||||||
|
echo "instance-id: $(get_ec2_metadata instance-id)"
|
||||||
|
echo "instance-type: $(get_ec2_metadata instance-type)"
|
||||||
|
- name: Log in to ECR
|
||||||
|
env:
|
||||||
|
AWS_RETRY_MODE: standard
|
||||||
|
AWS_MAX_ATTEMPTS: 5
|
||||||
|
run: |
|
||||||
|
aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh
|
||||||
|
bash /tmp/ecr-login.sh
|
||||||
|
rm /tmp/ecr-login.sh
|
||||||
|
- name: Chown workspace
|
||||||
|
env:
|
||||||
|
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||||
|
run: |
|
||||||
|
retry () {
|
||||||
|
"$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
|
||||||
|
}
|
||||||
|
retry docker pull "${ALPINE_IMAGE}"
|
||||||
|
# Ensure the working directory gets chowned back to the current user
|
||||||
|
docker run --pull=never --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
|
||||||
|
- name: Clean workspace
|
||||||
|
run: |
|
||||||
|
rm -rf "${GITHUB_WORKSPACE:?}/*"
|
||||||
|
rm -f ~/.ssh/authorized_keys
|
||||||
|
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
|
||||||
|
uses: seemethere/add-github-ssh-key@v1
|
||||||
|
with:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Preserve github env variables for use in docker
|
||||||
|
run: |
|
||||||
|
env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}"
|
||||||
|
- name: Checkout PyTorch
|
||||||
|
uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9
|
||||||
|
with:
|
||||||
|
# deep clone, to allow use of git merge-base
|
||||||
|
fetch-depth: 0
|
||||||
|
submodules: recursive
|
||||||
|
- name: Pull Docker image
|
||||||
|
run: |
|
||||||
|
retry () {
|
||||||
|
"$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
|
||||||
|
}
|
||||||
|
retry docker pull "${DOCKER_IMAGE}"
|
||||||
|
- name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG
|
||||||
|
if: ${{ contains(env.BUILD_ENVIRONMENT, 'cuda') && !contains(matrix.config, 'nogpu') }}
|
||||||
|
run: |
|
||||||
|
bash .github/scripts/install_nvidia_utils_linux.sh
|
||||||
|
echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}"
|
||||||
|
- name: Determine shm-size
|
||||||
|
run: |
|
||||||
|
shm_size="1g"
|
||||||
|
case "${BUILD_ENVIRONMENT}" in
|
||||||
|
*cuda*)
|
||||||
|
shm_size="2g"
|
||||||
|
;;
|
||||||
|
*rocm*)
|
||||||
|
shm_size="8g"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
echo "SHM_SIZE=${shm_size}" >> "${GITHUB_ENV}"
|
||||||
|
- uses: seemethere/download-artifact-s3@0504774707cbc8603d7dca922e8026eb8bf3b47b
|
||||||
|
name: Download PyTorch Build Artifacts
|
||||||
|
with:
|
||||||
|
name: ${{ env.BUILD_ENVIRONMENT }}
|
||||||
|
- name: Unzip artifacts
|
||||||
|
run: |
|
||||||
|
unzip -o artifacts.zip
|
||||||
|
- name: Output disk space left
|
||||||
|
run: |
|
||||||
|
sudo df -H
|
||||||
|
- name: Parse ref
|
||||||
|
id: parse-ref
|
||||||
|
run: .github/scripts/parse_ref.py
|
||||||
|
- name: Test
|
||||||
|
env:
|
||||||
|
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
|
CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||||
|
# Time out the test phase after 240 minutes
|
||||||
|
timeout-minutes: 240
|
||||||
|
run: |
|
||||||
|
set -x
|
||||||
|
|
||||||
|
if [[ $TEST_CONFIG == 'multigpu' ]]; then
|
||||||
|
TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
|
||||||
|
elif [[ $BUILD_ENVIRONMENT == *onnx* ]]; then
|
||||||
|
TEST_COMMAND=.jenkins/caffe2/test.sh
|
||||||
|
else
|
||||||
|
TEST_COMMAND=.jenkins/pytorch/test.sh
|
||||||
|
fi
|
||||||
|
# detached container should get cleaned up by teardown_ec2_linux
|
||||||
|
# TODO: Stop building test binaries as part of the build phase
|
||||||
|
# Used for GPU_FLAG since that doesn't play nice
|
||||||
|
# shellcheck disable=SC2086
|
||||||
|
container_name=$(docker run \
|
||||||
|
${GPU_FLAG:-} \
|
||||||
|
-e BUILD_ENVIRONMENT \
|
||||||
|
-e PR_NUMBER \
|
||||||
|
-e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
|
||||||
|
-e GITHUB_ACTIONS \
|
||||||
|
-e IN_CI \
|
||||||
|
-e IS_GHA \
|
||||||
|
-e CIRCLE_BRANCH \
|
||||||
|
-e CIRCLE_SHA1 \
|
||||||
|
-e CIRCLE_PR_NUMBER \
|
||||||
|
-e AWS_DEFAULT_REGION \
|
||||||
|
-e IN_WHEEL_TEST \
|
||||||
|
-e SHARD_NUMBER \
|
||||||
|
-e JOB_BASE_NAME \
|
||||||
|
-e TEST_CONFIG \
|
||||||
|
-e NUM_TEST_SHARDS \
|
||||||
|
-e PYTORCH_IGNORE_DISABLED_ISSUES \
|
||||||
|
-e PR_LABELS \
|
||||||
|
-e MAX_JOBS="$(nproc --ignore=2)" \
|
||||||
|
-e SCCACHE_BUCKET \
|
||||||
|
-e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,github.com,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \
|
||||||
|
-e XLA_CLANG_CACHE_S3_BUCKET_NAME \
|
||||||
|
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
|
||||||
|
--ulimit stack=10485760:83886080 \
|
||||||
|
--security-opt seccomp=unconfined \
|
||||||
|
--cap-add=SYS_PTRACE \
|
||||||
|
--shm-size="${SHM_SIZE}" \
|
||||||
|
--tty \
|
||||||
|
--detach \
|
||||||
|
--name="${container_name}" \
|
||||||
|
--user jenkins \
|
||||||
|
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
|
||||||
|
-w /var/lib/jenkins/workspace \
|
||||||
|
"${DOCKER_IMAGE}"
|
||||||
|
)
|
||||||
|
docker exec -t "${container_name}" sh -c "sudo chown -R jenkins . && pip install dist/*.whl && ${TEST_COMMAND}"
|
||||||
|
- name: Chown workspace
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
# Ensure the working directory gets chowned back to the current user
|
||||||
|
docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
|
||||||
|
- name: Install render_test_results dependencies
|
||||||
|
if: always()
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 -m pip install junitparser==2.1.1 rich==10.9.0
|
||||||
|
- name: "[[ Click me for rendered test results (useful for finding failing tests) ]]"
|
||||||
|
if: always()
|
||||||
|
shell: bash
|
||||||
|
# Encoding is weird on windows, just try to default to utf-8 if possible
|
||||||
|
env:
|
||||||
|
PYTHONIOENCODING: "utf-8"
|
||||||
|
run: |
|
||||||
|
python3 tools/render_junit.py test/
|
||||||
|
- name: Zip test reports for upload
|
||||||
|
if: always()
|
||||||
|
env:
|
||||||
|
FILE_SUFFIX: '${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}'
|
||||||
|
run: |
|
||||||
|
# Remove any previous test reports if they exist
|
||||||
|
rm -f test-reports-*.zip
|
||||||
|
zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml'
|
||||||
|
- uses: seemethere/upload-artifact-s3@v3
|
||||||
|
name: Store Test Reports on S3
|
||||||
|
if: always()
|
||||||
|
with:
|
||||||
|
retention-days: 14
|
||||||
|
if-no-files-found: error
|
||||||
|
path:
|
||||||
|
test-reports-*.zip
|
||||||
|
- name: Display and upload test statistics (Click Me)
|
||||||
|
if: always()
|
||||||
|
# temporary hack: set CIRCLE_* vars, until we update
|
||||||
|
# tools/stats/print_test_stats.py to natively support GitHub Actions
|
||||||
|
env:
|
||||||
|
AWS_DEFAULT_REGION: us-east-1
|
||||||
|
CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
|
||||||
|
JOB_BASE_NAME: linux-xenial-py3.6-gcc7-test
|
||||||
|
CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||||
|
CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||||
|
CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
|
||||||
|
CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 -m pip install -r requirements.txt
|
||||||
|
python3 -m pip install boto3==1.16.34
|
||||||
|
python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test
|
||||||
|
- name: Hold runner for 2 hours or until ssh sessions have drained
|
||||||
|
# Always hold for active ssh sessions
|
||||||
|
if: always()
|
||||||
|
run: .github/scripts/wait_for_ssh_to_drain.sh
|
||||||
|
- name: Chown workspace
|
||||||
|
if: always()
|
||||||
|
env:
|
||||||
|
ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
|
||||||
|
run: |
|
||||||
|
# Ensure the working directory gets chowned back to the current user
|
||||||
|
docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
|
||||||
|
- name: Kill containers, clean up images
|
||||||
|
if: always()
|
||||||
|
run: |
|
||||||
|
# ignore expansion of "docker ps -q" since it could be empty
|
||||||
|
# shellcheck disable=SC2046
|
||||||
|
docker stop $(docker ps -q) || true
|
||||||
|
# Prune all of the docker images
|
||||||
|
docker system prune -af
|
||||||
|
|
@ -327,9 +327,11 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||||
set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
||||||
set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
||||||
set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||||
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||||
|
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||||
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
|
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
|
||||||
endif()
|
endif()
|
||||||
|
endif()
|
||||||
if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND
|
if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND
|
||||||
CMAKE_C_COMPILER_VERSION VERSION_GREATER 6.9 AND CMAKE_C_COMPILER_VERSION VERSION_LESS 8)
|
CMAKE_C_COMPILER_VERSION VERSION_GREATER 6.9 AND CMAKE_C_COMPILER_VERSION VERSION_LESS 8)
|
||||||
set(GCC_7 True)
|
set(GCC_7 True)
|
||||||
|
|
|
||||||
|
|
@ -548,6 +548,7 @@ static void check_shape_forward(const at::Tensor& input,
|
||||||
", expected input", input.sizes(), " to have ",
|
", expected input", input.sizes(), " to have ",
|
||||||
(weight_sizes[1] * groups), " channels, but got ", input.size(1),
|
(weight_sizes[1] * groups), " channels, but got ", input.size(1),
|
||||||
" channels instead");
|
" channels instead");
|
||||||
|
|
||||||
TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]),
|
TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]),
|
||||||
"Given weight of size ", weight_sizes,
|
"Given weight of size ", weight_sizes,
|
||||||
", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
|
", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",
|
||||||
|
|
@ -850,7 +851,7 @@ at::Tensor _convolution(
|
||||||
|
|
||||||
check_shape_forward(input, weight_sizes, bias, params);
|
check_shape_forward(input, weight_sizes, bias, params);
|
||||||
|
|
||||||
if (input.size(0) == 0) {
|
if (input.size(0) == 0 || input.size(1) == 0) {
|
||||||
// don't send empty inputs through backends
|
// don't send empty inputs through backends
|
||||||
// but need to compute correct output size first and set up history for params
|
// but need to compute correct output size first and set up history for params
|
||||||
std::vector<int64_t> o;
|
std::vector<int64_t> o;
|
||||||
|
|
@ -862,6 +863,9 @@ at::Tensor _convolution(
|
||||||
params.output_padding, params.stride, params.dilation,
|
params.output_padding, params.stride, params.dilation,
|
||||||
params.groups);
|
params.groups);
|
||||||
}
|
}
|
||||||
|
if (input.size(1) == 0) {
|
||||||
|
o[input_channels_dim] = 0;
|
||||||
|
}
|
||||||
if (input_is_mkldnn && weight.is_mkldnn()) {
|
if (input_is_mkldnn && weight.is_mkldnn()) {
|
||||||
// mkldnn will error on the below 0-dim handling code
|
// mkldnn will error on the below 0-dim handling code
|
||||||
return empty_mkldnn(
|
return empty_mkldnn(
|
||||||
|
|
@ -871,10 +875,12 @@ at::Tensor _convolution(
|
||||||
input.options().device_opt(),
|
input.options().device_opt(),
|
||||||
input.options().pinned_memory_opt());
|
input.options().pinned_memory_opt());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto weight_view = at::_unsafe_view(weight, -1);
|
auto weight_view = at::_unsafe_view(weight, -1);
|
||||||
auto out = input*weight_view[0];
|
auto out = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
|
||||||
if (bias.defined())
|
if (bias.defined()) {
|
||||||
out.add_(bias[0]);
|
out.add_(bias[0]);
|
||||||
|
}
|
||||||
return out.view(o);
|
return out.view(o);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,20 +52,17 @@ static inline void slow_conv2d_shape_check(
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t ndim = input.dim();
|
const int64_t ndim = input.dim();
|
||||||
const int64_t dim_batch = 0;
|
|
||||||
const int64_t dim_planes = 1;
|
const int64_t dim_planes = 1;
|
||||||
const int64_t dim_height = 2;
|
const int64_t dim_height = 2;
|
||||||
const int64_t dim_width = 3;
|
const int64_t dim_width = 3;
|
||||||
|
|
||||||
// Allow for empty batch size but not other dimensions
|
// Allow for empty batch size and channel size but not other dimensions
|
||||||
bool valid_empty = ndim == 4 && input.size(dim_batch) == 0 &&
|
TORCH_CHECK(ndim == 4, "Expected 4D input tensor, but got: ", input.sizes());
|
||||||
input.size(dim_planes) != 0 && input.size(dim_height) != 0 &&
|
for (int64_t dim = 2; dim < ndim; ++dim) {
|
||||||
input.size(dim_width) != 0;
|
TORCH_CHECK(input.size(dim) != 0,
|
||||||
|
"Expected non-zero size for input dimension ", dim,
|
||||||
TORCH_CHECK(
|
", but got input shape: ", input.sizes(), ". Only the batch and channel dimensions support size 0.");
|
||||||
(input.numel() > 0 || valid_empty) && ndim == 4,
|
}
|
||||||
"non-empty 4D input tensor expected but got: ",
|
|
||||||
input.sizes());
|
|
||||||
|
|
||||||
const int64_t input_height = input.size(dim_height);
|
const int64_t input_height = input.size(dim_height);
|
||||||
const int64_t input_width = input.size(dim_width);
|
const int64_t input_width = input.size(dim_width);
|
||||||
|
|
@ -109,8 +106,10 @@ static inline void slow_conv2d_shape_check(
|
||||||
if (weight.dim() == 2) {
|
if (weight.dim() == 2) {
|
||||||
n_input_plane /= (kernel_height * kernel_width);
|
n_input_plane /= (kernel_height * kernel_width);
|
||||||
}
|
}
|
||||||
|
if (input.size(1) != 0) {
|
||||||
check_dim_size(input, ndim, dim_planes, n_input_plane);
|
check_dim_size(input, ndim, dim_planes, n_input_plane);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (grad_output.defined()) {
|
if (grad_output.defined()) {
|
||||||
if (weight.defined()) {
|
if (weight.defined()) {
|
||||||
|
|
@ -529,6 +528,7 @@ std::tuple<Tensor, Tensor> slow_conv2d_forward_cpu(
|
||||||
padding,
|
padding,
|
||||||
output,
|
output,
|
||||||
finput);
|
finput);
|
||||||
|
|
||||||
return std::make_tuple(output, finput);
|
return std::make_tuple(output, finput);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -559,6 +559,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cpu(
|
||||||
at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
|
at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if (grad_weight.defined()) {
|
if (grad_weight.defined()) {
|
||||||
grad_weight.resize_(weight.sizes());
|
grad_weight.resize_(weight.sizes());
|
||||||
grad_weight.zero_();
|
grad_weight.zero_();
|
||||||
|
|
|
||||||
|
|
@ -82,11 +82,11 @@ static void BM_deep_wide_static(benchmark::State& state) {
|
||||||
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
||||||
auto wide = torch::randn({batch_size, num_features});
|
auto wide = torch::randn({batch_size, num_features});
|
||||||
|
|
||||||
std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
|
std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});
|
||||||
|
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -104,11 +104,11 @@ static void BM_deep_wide_static_threaded(benchmark::State& state) {
|
||||||
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
||||||
auto wide = torch::randn({batch_size, num_features});
|
auto wide = torch::randn({batch_size, num_features});
|
||||||
|
|
||||||
std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
|
std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});
|
||||||
|
|
||||||
sr(inputs);
|
sr(inputs, {});
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
sr(inputs);
|
sr(inputs, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -118,11 +118,11 @@ static void BM_leaky_relu_const(benchmark::State& state) {
|
||||||
|
|
||||||
const int batch_size = state.range(0);
|
const int batch_size = state.range(0);
|
||||||
auto data = torch::randn({batch_size, num_features});
|
auto data = torch::randn({batch_size, num_features});
|
||||||
std::vector<at::Tensor> inputs({data});
|
std::vector<c10::IValue> inputs({data});
|
||||||
|
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -133,11 +133,11 @@ static void BM_leaky_relu(benchmark::State& state) {
|
||||||
const int batch_size = state.range(0);
|
const int batch_size = state.range(0);
|
||||||
auto neg_slope = torch::randn(1);
|
auto neg_slope = torch::randn(1);
|
||||||
auto data = torch::randn({batch_size, num_features});
|
auto data = torch::randn({batch_size, num_features});
|
||||||
std::vector<at::Tensor> inputs({data, neg_slope[0]});
|
std::vector<c10::IValue> inputs({data, neg_slope[0]});
|
||||||
|
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -150,11 +150,11 @@ static void BM_signed_log1p(benchmark::State& state) {
|
||||||
|
|
||||||
const int num_elements = state.range(0);
|
const int num_elements = state.range(0);
|
||||||
auto data = torch::randn({num_elements});
|
auto data = torch::randn({num_elements});
|
||||||
std::vector<at::Tensor> inputs({data});
|
std::vector<c10::IValue> inputs({data});
|
||||||
|
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -170,11 +170,11 @@ static void BM_long_static_memory_optimization(benchmark::State& state) {
|
||||||
auto a = torch::randn({N, N});
|
auto a = torch::randn({N, N});
|
||||||
auto b = torch::randn({N, N});
|
auto b = torch::randn({N, N});
|
||||||
auto c = torch::randn({N, N});
|
auto c = torch::randn({N, N});
|
||||||
std::vector<at::Tensor> inputs({a, b, c});
|
std::vector<c10::IValue> inputs({a, b, c});
|
||||||
|
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
for (auto _ : state) {
|
for (auto _ : state) {
|
||||||
smod(inputs);
|
smod(inputs, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -755,9 +755,9 @@ TEST(StaticRuntime, LongModel) {
|
||||||
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
||||||
|
|
||||||
// run static runtime
|
// run static runtime
|
||||||
std::vector<at::Tensor> input_tensors({a, b, c});
|
std::vector<c10::IValue> input_tensors({a, b, c});
|
||||||
torch::jit::StaticModule smod(mod);
|
torch::jit::StaticModule smod(mod);
|
||||||
at::Tensor output_2 = smod(input_tensors)[0];
|
at::Tensor output_2 = smod(input_tensors, {}).toTensor();
|
||||||
smod.runtime().check_for_memory_leak();
|
smod.runtime().check_for_memory_leak();
|
||||||
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
@ -773,9 +773,9 @@ TEST(StaticRuntime, TrivialModel) {
|
||||||
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
||||||
|
|
||||||
// run static runtime
|
// run static runtime
|
||||||
std::vector<at::Tensor> input_tensors({a, b, c});
|
std::vector<c10::IValue> input_tensors({a, b, c});
|
||||||
torch::jit::StaticModule smod(mod);
|
torch::jit::StaticModule smod(mod);
|
||||||
at::Tensor output_2 = smod(input_tensors)[0];
|
at::Tensor output_2 = smod(input_tensors, {}).toTensor();
|
||||||
smod.runtime().check_for_memory_leak();
|
smod.runtime().check_for_memory_leak();
|
||||||
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
@ -789,9 +789,9 @@ TEST(StaticRuntime, LeakyReLU) {
|
||||||
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
|
||||||
|
|
||||||
// run static runtime
|
// run static runtime
|
||||||
std::vector<at::Tensor> input_tensors({inputs});
|
std::vector<c10::IValue> input_tensors({inputs});
|
||||||
torch::jit::StaticModule smod(mod);
|
torch::jit::StaticModule smod(mod);
|
||||||
at::Tensor output_2 = smod(input_tensors)[0];
|
at::Tensor output_2 = smod(input_tensors, {}).toTensor();
|
||||||
smod.runtime().check_for_memory_leak();
|
smod.runtime().check_for_memory_leak();
|
||||||
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
@ -813,8 +813,10 @@ TEST(StaticRuntime, DeepWide) {
|
||||||
auto output_1 = getTensor(mod.forward(inputs));
|
auto output_1 = getTensor(mod.forward(inputs));
|
||||||
|
|
||||||
// run static runtime
|
// run static runtime
|
||||||
std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
|
std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
|
||||||
at::Tensor output_2 = smod(input_tensors)[0];
|
auto outputs = smod(input_tensors, {}).toTuple()->elements();
|
||||||
|
ASSERT_TRUE(outputs.size() > 0);
|
||||||
|
at::Tensor output_2 = outputs[0].toTensor();
|
||||||
smod.runtime().check_for_memory_leak();
|
smod.runtime().check_for_memory_leak();
|
||||||
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
||||||
}
|
}
|
||||||
|
|
@ -947,9 +949,11 @@ TEST(StaticRuntime, CleanUpMemory) {
|
||||||
auto output_1 = getTensor(mod.forward(inputs));
|
auto output_1 = getTensor(mod.forward(inputs));
|
||||||
|
|
||||||
// run static runtime
|
// run static runtime
|
||||||
std::vector<at::Tensor> input_tensors(
|
std::vector<c10::IValue> input_tensors(
|
||||||
{ad_emb_packed, user_emb, wide});
|
{ad_emb_packed, user_emb, wide});
|
||||||
at::Tensor output_2 = runtime(input_tensors)[0];
|
auto outputs = runtime(input_tensors, {}).toTuple()->elements();
|
||||||
|
ASSERT_TRUE(outputs.size() > 0);
|
||||||
|
auto output_2 = outputs[0].toTensor();
|
||||||
runtime.check_for_memory_leak();
|
runtime.check_for_memory_leak();
|
||||||
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
|
||||||
if (manage_output_tensors) {
|
if (manage_output_tensors) {
|
||||||
|
|
@ -1053,9 +1057,9 @@ TEST(StaticRuntime, ManageOutputTensorsWithDeallocateOutputTensors) {
|
||||||
torch::randn({batch_size, 1, embedding_size});
|
torch::randn({batch_size, 1, embedding_size});
|
||||||
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
||||||
auto wide = torch::randn({batch_size, num_features});
|
auto wide = torch::randn({batch_size, num_features});
|
||||||
std::vector<at::Tensor> input_tensors(
|
std::vector<c10::IValue> input_tensors(
|
||||||
{ad_emb_packed, user_emb, wide});
|
{ad_emb_packed, user_emb, wide});
|
||||||
runtime(input_tensors)[0];
|
runtime(input_tensors, {});
|
||||||
runtime.check_for_memory_leak();
|
runtime.check_for_memory_leak();
|
||||||
runtime.deallocateOutputTensors();
|
runtime.deallocateOutputTensors();
|
||||||
runtime.checkOutputTensorMemoryLeaks();
|
runtime.checkOutputTensorMemoryLeaks();
|
||||||
|
|
@ -1079,21 +1083,21 @@ TEST(StaticRuntime, ManageOutputTensorsWithoutDeallocateOutputTensors) {
|
||||||
torch::randn({batch_size, 1, embedding_size});
|
torch::randn({batch_size, 1, embedding_size});
|
||||||
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
auto user_emb = torch::randn({batch_size, 1, embedding_size});
|
||||||
auto wide = torch::randn({batch_size, num_features});
|
auto wide = torch::randn({batch_size, num_features});
|
||||||
std::vector<at::Tensor> input_tensors(
|
std::vector<c10::IValue> input_tensors(
|
||||||
{ad_emb_packed, user_emb, wide});
|
{ad_emb_packed, user_emb, wide});
|
||||||
// Profile run.
|
// Profile run.
|
||||||
runtime(input_tensors)[0];
|
runtime(input_tensors, {});
|
||||||
runtime.deallocateOutputTensors();
|
runtime.deallocateOutputTensors();
|
||||||
// Run again to allocate output Tensors without deallocating them.
|
// Run again to allocate output Tensors without deallocating them.
|
||||||
runtime(input_tensors)[0];
|
runtime(input_tensors, {});
|
||||||
// Memory leak checking fails.
|
// Memory leak checking fails.
|
||||||
EXPECT_THROW(runtime.checkOutputTensorMemoryLeaks(), std::exception);
|
EXPECT_THROW(runtime.checkOutputTensorMemoryLeaks(), std::exception);
|
||||||
// Calling the runtime without deallocation fails too.
|
// Calling the runtime without deallocation fails too.
|
||||||
EXPECT_THROW(runtime(input_tensors)[0], std::exception);
|
EXPECT_THROW(runtime(input_tensors, {}), std::exception);
|
||||||
// After deallocation, everything works fine.
|
// After deallocation, everything works fine.
|
||||||
runtime.deallocateOutputTensors();
|
runtime.deallocateOutputTensors();
|
||||||
runtime.checkOutputTensorMemoryLeaks();
|
runtime.checkOutputTensorMemoryLeaks();
|
||||||
runtime(input_tensors)[0];
|
runtime(input_tensors, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(StaticRuntime, FusionPass) {
|
TEST(StaticRuntime, FusionPass) {
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@
|
||||||
#include <torch/csrc/jit/runtime/custom_operator.h>
|
#include <torch/csrc/jit/runtime/custom_operator.h>
|
||||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||||
#include <torch/csrc/jit/runtime/interpreter.h>
|
#include <torch/csrc/jit/runtime/interpreter.h>
|
||||||
|
#include <torch/csrc/jit/runtime/jit_trace.h>
|
||||||
#include <torch/csrc/jit/runtime/profiling_record.h>
|
#include <torch/csrc/jit/runtime/profiling_record.h>
|
||||||
#include <torch/csrc/jit/runtime/symbolic_script.h>
|
#include <torch/csrc/jit/runtime/symbolic_script.h>
|
||||||
#include <torch/csrc/jit/serialization/import.h>
|
#include <torch/csrc/jit/serialization/import.h>
|
||||||
|
|
@ -58,6 +59,8 @@
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/ThreadLocalDebugInfo.h>
|
#include <c10/util/ThreadLocalDebugInfo.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/passes/freeze_module.h>
|
||||||
|
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
@ -67,6 +70,7 @@
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -1803,6 +1807,46 @@ TEST(LoopPeelerTest, SimpleNestedLoops2) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(JitTracing, Basic) {
|
||||||
|
constexpr int batch_size = 4;
|
||||||
|
constexpr int input_size = 256;
|
||||||
|
|
||||||
|
int hidden_size = 2 * input_size;
|
||||||
|
|
||||||
|
auto input = at::randn({batch_size, input_size}, at::kCPU);
|
||||||
|
auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
|
||||||
|
auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
|
||||||
|
auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
|
||||||
|
auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
|
||||||
|
|
||||||
|
auto graph = build_lstm();
|
||||||
|
auto stack = createStack({input, hx, cx, w_ih, w_hh});
|
||||||
|
auto traced = TraceGraph(graph, stack);
|
||||||
|
Tensor prof_out;
|
||||||
|
pop(stack, prof_out);
|
||||||
|
|
||||||
|
{
|
||||||
|
stack = createStack({input, hx, cx, w_ih, w_hh});
|
||||||
|
Code cd(traced, "traced");
|
||||||
|
InterpreterState is{cd};
|
||||||
|
is.run(stack);
|
||||||
|
Tensor traced_out;
|
||||||
|
pop(stack, traced_out);
|
||||||
|
torch::allclose(prof_out, traced_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
stack = createStack({input, hx, cx, w_ih, w_hh});
|
||||||
|
Code cd(graph, "graph");
|
||||||
|
InterpreterState is{cd};
|
||||||
|
is.run(stack);
|
||||||
|
Tensor scripted_out;
|
||||||
|
pop(stack, scripted_out);
|
||||||
|
torch::allclose(prof_out, scripted_out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
|
||||||
TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
|
TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
|
||||||
static const auto basic_example = R"JIT(
|
static const auto basic_example = R"JIT(
|
||||||
def basic(x, y):
|
def basic(x, y):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.common_utils import TestCase
|
from torch.testing._internal.common_utils import TestCase
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing import FileCheck
|
from torch.testing import FileCheck
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import gc
|
import gc
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import unittest
|
import unittest
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
@ -1289,6 +1291,101 @@ class TestFreezing(JitTestCase):
|
||||||
FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
|
FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
|
||||||
FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)
|
FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)
|
||||||
|
|
||||||
|
def test_freeze_module_with_user_preserved_attribute_on_submodule(self):
|
||||||
|
class SubModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SubModule, self).__init__()
|
||||||
|
self.a = 1
|
||||||
|
self.b = 2
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self.a + self.b
|
||||||
|
|
||||||
|
class Module(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Module, self).__init__()
|
||||||
|
self.sub1 = SubModule()
|
||||||
|
self.sub2 = SubModule()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self.sub1() + self.sub2()
|
||||||
|
|
||||||
|
m = torch.jit.script(Module())
|
||||||
|
m.eval()
|
||||||
|
m = torch.jit.freeze(m, preserved_attrs=['sub1.a', 'sub2.a'])
|
||||||
|
fm = m._c
|
||||||
|
|
||||||
|
self.assertTrue(fm.hasattr('sub1'))
|
||||||
|
self.assertTrue(fm.sub1.hasattr('a'))
|
||||||
|
self.assertFalse(fm.sub1.hasattr('b'))
|
||||||
|
self.assertTrue(fm.hasattr('sub2'))
|
||||||
|
self.assertTrue(fm.sub2.hasattr('a'))
|
||||||
|
self.assertFalse(fm.sub2.hasattr('b'))
|
||||||
|
self.assertEqual(m(), 6)
|
||||||
|
m.sub1.a += 1
|
||||||
|
self.assertEqual(m(), 7)
|
||||||
|
|
||||||
|
def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self):
|
||||||
|
class SubModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SubModule, self).__init__()
|
||||||
|
self.a = 1
|
||||||
|
self.b = 2
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self.a + self.b
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
|
def method_a(self):
|
||||||
|
return 42
|
||||||
|
|
||||||
|
class Module(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Module, self).__init__()
|
||||||
|
self.sub = SubModule()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
m = torch.jit.script(Module())
|
||||||
|
m.eval()
|
||||||
|
fm = torch.jit.freeze(m, preserved_attrs=['sub.a', 'sub.method_a'])._c
|
||||||
|
|
||||||
|
self.assertTrue(fm.hasattr('sub'))
|
||||||
|
self.assertTrue(fm.sub.hasattr('a'))
|
||||||
|
self.assertFalse(fm.sub.hasattr('b'))
|
||||||
|
self.assertTrue(fm.sub._has_method('method_a'))
|
||||||
|
|
||||||
|
def test_freeze_module_with_user_preserved_method_on_submodule(self):
|
||||||
|
class SubModule(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(SubModule, self).__init__()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.method_a(x) + self.method_b(x)
|
||||||
|
|
||||||
|
def method_a(self, x):
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
def method_b(self, x):
|
||||||
|
return x + x
|
||||||
|
|
||||||
|
class Module(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Module, self).__init__()
|
||||||
|
self.sub = SubModule()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.sub(x)
|
||||||
|
|
||||||
|
m = torch.jit.script(Module())
|
||||||
|
m.eval()
|
||||||
|
fm = torch.jit.freeze(m, preserved_attrs=['sub.method_a'])._c
|
||||||
|
|
||||||
|
self.assertTrue(fm.hasattr('sub'))
|
||||||
|
self.assertTrue(fm.sub._has_method('method_a'))
|
||||||
|
self.assertFalse(fm.sub._has_method('method_b'))
|
||||||
|
|
||||||
@skipIfNoFBGEMM
|
@skipIfNoFBGEMM
|
||||||
def test_module_with_shared_type_instances(self):
|
def test_module_with_shared_type_instances(self):
|
||||||
class Child(nn.Module):
|
class Child(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
import torch
|
import torch
|
||||||
import torch._C
|
import torch._C
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import unittest
|
import unittest
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
# TODO: enable linting check for this file
|
# TODO: enable linting check for this file
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["module: onnx"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: mobile"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._C
|
import torch._C
|
||||||
import torch.backends.xnnpack
|
import torch.backends.xnnpack
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
|
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.jit_utils import JitTestCase
|
from torch.testing._internal.jit_utils import JitTestCase
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import types
|
import types
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from itertools import product as product
|
from itertools import product as product
|
||||||
from typing import NamedTuple, Optional
|
from typing import NamedTuple, Optional
|
||||||
import io
|
import io
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
||||||
import operator
|
import operator
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import io
|
import io
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import io
|
import io
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1481,6 +1481,7 @@ class TestNormalizeOperators(JitTestCase):
|
||||||
"index_put",
|
"index_put",
|
||||||
"nn.functional.conv2d",
|
"nn.functional.conv2d",
|
||||||
"nn.functional.dropout",
|
"nn.functional.dropout",
|
||||||
|
"nn.functional.embedding", # Implemented with a lambda
|
||||||
"polygamma",
|
"polygamma",
|
||||||
"special.polygamma",
|
"special.polygamma",
|
||||||
"repeat",
|
"repeat",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# This is how we include tests located in test/jit/...
|
# This is how we include tests located in test/jit/...
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.argv.append("--jit_executor=legacy")
|
sys.argv.append("--jit_executor=legacy")
|
||||||
from test_jit_fuser import * # noqa: F403
|
from test_jit_fuser import * # noqa: F403
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.argv.append("--jit_executor=legacy")
|
sys.argv.append("--jit_executor=legacy")
|
||||||
from test_jit import * # noqa: F403
|
from test_jit import * # noqa: F403
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.argv.append("--jit_executor=profiling")
|
sys.argv.append("--jit_executor=profiling")
|
||||||
from test_jit import * # noqa: F403
|
from test_jit import * # noqa: F403
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.argv.append("--jit_executor=simple")
|
sys.argv.append("--jit_executor=simple")
|
||||||
from test_jit import * # noqa: F403
|
from test_jit import * # noqa: F403
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
# Owner(s): ["oncall: jit"]
|
||||||
|
|
||||||
from test_jit import JitTestCase
|
from test_jit import JitTestCase
|
||||||
from torch.testing._internal.common_utils import run_tests
|
from torch.testing._internal.common_utils import run_tests
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13982,6 +13982,32 @@ class TestNNDeviceType(NNTestCase):
|
||||||
self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
|
self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
|
||||||
self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))
|
self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))
|
||||||
|
|
||||||
|
def test_conv_empty_channel(self, device):
|
||||||
|
in_channels = 0
|
||||||
|
mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2).to(device)
|
||||||
|
inp = torch.randn(2, 0, 15, device=device)
|
||||||
|
self._test_module_empty_input(mod, inp, check_size=False)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
||||||
|
inp = torch.randn(2, 1, 0, device=device)
|
||||||
|
mod(inp)
|
||||||
|
|
||||||
|
mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2).to(device)
|
||||||
|
inp = torch.randn(2, 0, 50, 100, device=device)
|
||||||
|
self._test_module_empty_input(mod, inp, check_size=False)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
||||||
|
inp = torch.randn(2, 1, 40, 0, device=device)
|
||||||
|
mod(inp)
|
||||||
|
|
||||||
|
mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2).to(device)
|
||||||
|
inp = torch.randn(2, 0, 50, 20, 40, device=device)
|
||||||
|
self._test_module_empty_input(mod, inp, check_size=False)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
|
||||||
|
inp = torch.randn(2, 1, 50, 0, 40, device=device)
|
||||||
|
mod(inp)
|
||||||
|
|
||||||
def test_group_conv_empty(self, device):
|
def test_group_conv_empty(self, device):
|
||||||
mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device)
|
mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device)
|
||||||
inp = torch.randn(0, 4, 4, 4, device=device)
|
inp = torch.randn(0, 4, 4, 4, device=device)
|
||||||
|
|
|
||||||
|
|
@ -247,8 +247,8 @@ class TestOptim(TestCase):
|
||||||
|
|
||||||
def _test_complex_optimizer(self, optimizer_constructor):
|
def _test_complex_optimizer(self, optimizer_constructor):
|
||||||
complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
|
complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
|
||||||
complex_opt = optimizer_constructor(complex_param)
|
|
||||||
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
|
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
|
||||||
|
complex_opt = optimizer_constructor(complex_param)
|
||||||
real_opt = optimizer_constructor(real_param)
|
real_opt = optimizer_constructor(real_param)
|
||||||
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
|
|
@ -652,11 +652,6 @@ class TestOptim(TestCase):
|
||||||
[param], lr=1e-1, initial_accumulator_value=0.1
|
[param], lr=1e-1, initial_accumulator_value=0.1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self._test_complex_optimizer(
|
|
||||||
lambda param: optimizer(
|
|
||||||
[param], lr=1e-1, initial_accumulator_value=0.1, weight_decay=1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_adamax(self):
|
def test_adamax(self):
|
||||||
for optimizer in [optim.Adamax, optim_mt.Adamax]:
|
for optimizer in [optim.Adamax, optim_mt.Adamax]:
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ class StaticModule:
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
return self.static_module(args)
|
return self.static_module(args, {})
|
||||||
else:
|
else:
|
||||||
return self.static_module(args, kwargs)
|
return self.static_module(args, kwargs)
|
||||||
|
|
||||||
|
|
@ -227,20 +227,20 @@ class TestStaticModule(TestCase):
|
||||||
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
||||||
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
||||||
ref_bot = bot_l(bot_inp)
|
ref_bot = bot_l(bot_inp)
|
||||||
acc_bot = bot_l_acc(bot_inp)[0]
|
acc_bot = bot_l_acc(bot_inp)
|
||||||
torch.testing.assert_close(acc_bot, ref_bot)
|
torch.testing.assert_close(acc_bot, ref_bot)
|
||||||
ref_top = top_l(top_inp)
|
ref_top = top_l(top_inp)
|
||||||
acc_top = top_l_acc(top_inp)[0]
|
acc_top = top_l_acc(top_inp)
|
||||||
torch.testing.assert_close(acc_top, ref_top)
|
torch.testing.assert_close(acc_top, ref_top)
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512])
|
||||||
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
top_inp = torch.randn(2048, 100) # torch.Size([2048, 100])
|
||||||
ref_bot = bot_l(bot_inp)
|
ref_bot = bot_l(bot_inp)
|
||||||
acc_bot = bot_l_acc(bot_inp)[0]
|
acc_bot = bot_l_acc(bot_inp)
|
||||||
torch.testing.assert_close(acc_bot, ref_bot)
|
torch.testing.assert_close(acc_bot, ref_bot)
|
||||||
ref_top = top_l(top_inp)
|
ref_top = top_l(top_inp)
|
||||||
acc_top = top_l_acc(top_inp)[0]
|
acc_top = top_l_acc(top_inp)
|
||||||
torch.testing.assert_close(acc_top, ref_top)
|
torch.testing.assert_close(acc_top, ref_top)
|
||||||
|
|
||||||
def test_trivial_graph(self):
|
def test_trivial_graph(self):
|
||||||
|
|
@ -248,7 +248,7 @@ class TestStaticModule(TestCase):
|
||||||
tg = torch.jit.script(trivial_graph)
|
tg = torch.jit.script(trivial_graph)
|
||||||
o_ref = tg(s, s, s)
|
o_ref = tg(s, s, s)
|
||||||
tg_a = StaticModule(tg)
|
tg_a = StaticModule(tg)
|
||||||
o_test = tg_a(s, s, s)[0]
|
o_test = tg_a(s, s, s)
|
||||||
torch.testing.assert_close(o_ref, o_test)
|
torch.testing.assert_close(o_ref, o_test)
|
||||||
|
|
||||||
def test_leaky_relu(self):
|
def test_leaky_relu(self):
|
||||||
|
|
@ -256,7 +256,7 @@ class TestStaticModule(TestCase):
|
||||||
tg = torch.jit.script(nn.LeakyReLU(0.1))
|
tg = torch.jit.script(nn.LeakyReLU(0.1))
|
||||||
o_ref = tg(s)
|
o_ref = tg(s)
|
||||||
tg_a = StaticModule(tg)
|
tg_a = StaticModule(tg)
|
||||||
o_test = tg_a(s)[0]
|
o_test = tg_a(s)
|
||||||
torch.testing.assert_close(o_ref, o_test)
|
torch.testing.assert_close(o_ref, o_test)
|
||||||
|
|
||||||
def test_attr(self):
|
def test_attr(self):
|
||||||
|
|
@ -292,7 +292,7 @@ class TestStaticModule(TestCase):
|
||||||
|
|
||||||
ms = torch.jit.script(m)
|
ms = torch.jit.script(m)
|
||||||
sm = StaticModule(ms)
|
sm = StaticModule(ms)
|
||||||
output_sm = sm(input)[0]
|
output_sm = sm(input)
|
||||||
torch.testing.assert_close(output_s, output_sm)
|
torch.testing.assert_close(output_s, output_sm)
|
||||||
sm.benchmark([input], {}, 2, 2)
|
sm.benchmark([input], {}, 2, 2)
|
||||||
sm.benchmark_individual_ops([input], {}, 2, 2)
|
sm.benchmark_individual_ops([input], {}, 2, 2)
|
||||||
|
|
|
||||||
|
|
@ -285,6 +285,7 @@ core_sources_full_mobile_no_backend_interface = [
|
||||||
"torch/csrc/jit/runtime/script_profile.cpp",
|
"torch/csrc/jit/runtime/script_profile.cpp",
|
||||||
"torch/csrc/jit/runtime/symbolic_script.cpp",
|
"torch/csrc/jit/runtime/symbolic_script.cpp",
|
||||||
"torch/csrc/jit/runtime/symbolic_shape_registry.cpp",
|
"torch/csrc/jit/runtime/symbolic_shape_registry.cpp",
|
||||||
|
"torch/csrc/jit/runtime/jit_trace.cpp",
|
||||||
"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
|
"torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
|
||||||
"torch/csrc/jit/serialization/import.cpp",
|
"torch/csrc/jit/serialization/import.cpp",
|
||||||
"torch/csrc/jit/serialization/import_export_helpers.cpp",
|
"torch/csrc/jit/serialization/import_export_helpers.cpp",
|
||||||
|
|
|
||||||
|
|
@ -639,13 +639,13 @@ def convert_fx(
|
||||||
|
|
||||||
* `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
|
* `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
|
||||||
|
|
||||||
* `qconfig_dict`: qconfig_dict with the either
|
* `qconfig_dict`: qconfig_dict with either same keys as what is passed to
|
||||||
a) same keys as what is passed to the qconfig_dict in prepare_fx API, with same values or `None`.
|
the qconfig_dict in `prepare_fx` API, with same values or `None`, or
|
||||||
b) additional keys with values set to `None`
|
additional keys with values set to `None`
|
||||||
For each entry whose value is set to None, we skip quantizing that entry in the model.
|
|
||||||
Example:
|
|
||||||
qconfig_dict = {
|
|
||||||
|
|
||||||
|
For each entry whose value is set to None, we skip quantizing that entry in the model::
|
||||||
|
|
||||||
|
qconfig_dict = {
|
||||||
# used for object_type, skip quantizing torch.nn.functional.add
|
# used for object_type, skip quantizing torch.nn.functional.add
|
||||||
"object_type": [
|
"object_type": [
|
||||||
(torch.nn.functional.add, None),
|
(torch.nn.functional.add, None),
|
||||||
|
|
@ -660,6 +660,7 @@ def convert_fx(
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A quantized model (GraphModule)
|
A quantized model (GraphModule)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,11 +11,10 @@ namespace deploy {
|
||||||
static const size_t NUM_FROZEN_PY_BUILTIN_MODULES = 6;
|
static const size_t NUM_FROZEN_PY_BUILTIN_MODULES = 6;
|
||||||
static const size_t NUM_FROZEN_PY_STDLIB_MODULES = 680;
|
static const size_t NUM_FROZEN_PY_STDLIB_MODULES = 680;
|
||||||
|
|
||||||
extern "C" struct _frozen _PyImport_FrozenModules[];
|
|
||||||
extern "C" struct _frozen _PyImport_FrozenModules_torch[];
|
extern "C" struct _frozen _PyImport_FrozenModules_torch[];
|
||||||
extern "C" PyObject* initModule(void);
|
extern "C" PyObject* initModule(void);
|
||||||
|
|
||||||
REGISTER_TORCH_DEPLOY_BUILTIN(cpython_internal, PyImport_FrozenModules);
|
REGISTER_TORCH_DEPLOY_BUILTIN(cpython_internal, PyImport_FrozenModules);
|
||||||
REGISTER_TORCH_DEPLOY_BUILTIN(frozenpython, _PyImport_FrozenModules);
|
|
||||||
REGISTER_TORCH_DEPLOY_BUILTIN(
|
REGISTER_TORCH_DEPLOY_BUILTIN(
|
||||||
frozentorch,
|
frozentorch,
|
||||||
_PyImport_FrozenModules_torch,
|
_PyImport_FrozenModules_torch,
|
||||||
|
|
|
||||||
|
|
@ -34,78 +34,6 @@ using namespace py::literals;
|
||||||
#define PYOBJ_ASSERT(obj) assert(NULL != obj);
|
#define PYOBJ_ASSERT(obj) assert(NULL != obj);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#define FOREACH_LIBRARY(_) \
|
|
||||||
_(array) \
|
|
||||||
_(_asyncio) \
|
|
||||||
_(audioop) \
|
|
||||||
_(binascii) \
|
|
||||||
_(_bisect) \
|
|
||||||
_(_blake2) \
|
|
||||||
_(_bz2) \
|
|
||||||
_(cmath) \
|
|
||||||
_(_codecs_cn) \
|
|
||||||
_(_codecs_hk) \
|
|
||||||
_(_codecs_iso2022) \
|
|
||||||
_(_codecs_jp) \
|
|
||||||
_(_codecs_kr) \
|
|
||||||
_(_codecs_tw) \
|
|
||||||
_(_contextvars) \
|
|
||||||
_(_crypt) \
|
|
||||||
_(_csv) \
|
|
||||||
_(_ctypes) \
|
|
||||||
_(_ctypes_test) \
|
|
||||||
_(_curses) \
|
|
||||||
_(_curses_panel) \
|
|
||||||
_(_datetime) \
|
|
||||||
_(_decimal) \
|
|
||||||
_(_elementtree) \
|
|
||||||
_(fcntl) \
|
|
||||||
_(grp) \
|
|
||||||
_(_hashlib) \
|
|
||||||
_(_heapq) \
|
|
||||||
_(_json) \
|
|
||||||
_(_lsprof) \
|
|
||||||
_(_lzma) \
|
|
||||||
_(math) \
|
|
||||||
_(_md5) \
|
|
||||||
_(mmap) \
|
|
||||||
_(_multibytecodec) \
|
|
||||||
_(_multiprocessing) \
|
|
||||||
_(nis) \
|
|
||||||
_(_opcode) \
|
|
||||||
_(ossaudiodev) \
|
|
||||||
_(parser) \
|
|
||||||
_(_pickle) \
|
|
||||||
_(_posixsubprocess) \
|
|
||||||
_(pyexpat) \
|
|
||||||
_(_queue) \
|
|
||||||
_(_random) \
|
|
||||||
_(readline) \
|
|
||||||
_(resource) \
|
|
||||||
_(select) \
|
|
||||||
_(_sha1) \
|
|
||||||
_(_sha256) \
|
|
||||||
_(_sha3) \
|
|
||||||
_(_sha512) \
|
|
||||||
_(_socket) \
|
|
||||||
_(spwd) \
|
|
||||||
_(_ssl) \
|
|
||||||
_(_struct) \
|
|
||||||
_(syslog) \
|
|
||||||
_(termios) \
|
|
||||||
_(_testbuffer) \
|
|
||||||
_(_testcapi) \
|
|
||||||
_(_testimportmultiple) \
|
|
||||||
_(_testmultiphase) \
|
|
||||||
_(unicodedata) \
|
|
||||||
_(xxlimited) \
|
|
||||||
_(_xxtestfuzz) \
|
|
||||||
_(zlib)
|
|
||||||
|
|
||||||
#define DECLARE_LIBRARY_INIT(name) extern "C" PyObject* PyInit_##name(void);
|
|
||||||
FOREACH_LIBRARY(DECLARE_LIBRARY_INIT)
|
|
||||||
#undef DECLARE_LIBRARY_INIT
|
|
||||||
|
|
||||||
const char* start = R"PYTHON(
|
const char* start = R"PYTHON(
|
||||||
import _ssl # must come before _hashlib otherwise ssl's locks will be set to a Python that might no longer exist...
|
import _ssl # must come before _hashlib otherwise ssl's locks will be set to a Python that might no longer exist...
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -221,10 +149,6 @@ struct InitLockAcquire {
|
||||||
struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
|
struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
|
||||||
: public torch::deploy::InterpreterImpl {
|
: public torch::deploy::InterpreterImpl {
|
||||||
ConcreteInterpreterImpl() {
|
ConcreteInterpreterImpl() {
|
||||||
#define APPEND_INIT(name) PyImport_AppendInittab(#name, PyInit_##name);
|
|
||||||
FOREACH_LIBRARY(APPEND_INIT)
|
|
||||||
#undef APPEND_INIT
|
|
||||||
|
|
||||||
BuiltinRegistry::runPreInitialization();
|
BuiltinRegistry::runPreInitialization();
|
||||||
|
|
||||||
PyPreConfig preconfig;
|
PyPreConfig preconfig;
|
||||||
|
|
|
||||||
82
torch/csrc/deploy/interpreter/register_frozenpython.cpp
Normal file
82
torch/csrc/deploy/interpreter/register_frozenpython.cpp
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
#include <Python.h>
|
||||||
|
#include <torch/csrc/deploy/interpreter/builtin_registry.h>
|
||||||
|
|
||||||
|
#define FOREACH_LIBRARY(_) \
|
||||||
|
_(array) \
|
||||||
|
_(_asyncio) \
|
||||||
|
_(audioop) \
|
||||||
|
_(binascii) \
|
||||||
|
_(_bisect) \
|
||||||
|
_(_blake2) \
|
||||||
|
_(_bz2) \
|
||||||
|
_(cmath) \
|
||||||
|
_(_codecs_cn) \
|
||||||
|
_(_codecs_hk) \
|
||||||
|
_(_codecs_iso2022) \
|
||||||
|
_(_codecs_jp) \
|
||||||
|
_(_codecs_kr) \
|
||||||
|
_(_codecs_tw) \
|
||||||
|
_(_contextvars) \
|
||||||
|
_(_crypt) \
|
||||||
|
_(_csv) \
|
||||||
|
_(_ctypes) \
|
||||||
|
_(_ctypes_test) \
|
||||||
|
_(_curses) \
|
||||||
|
_(_curses_panel) \
|
||||||
|
_(_datetime) \
|
||||||
|
_(_decimal) \
|
||||||
|
_(_elementtree) \
|
||||||
|
_(fcntl) \
|
||||||
|
_(grp) \
|
||||||
|
_(_hashlib) \
|
||||||
|
_(_heapq) \
|
||||||
|
_(_json) \
|
||||||
|
_(_lsprof) \
|
||||||
|
_(_lzma) \
|
||||||
|
_(math) \
|
||||||
|
_(_md5) \
|
||||||
|
_(mmap) \
|
||||||
|
_(_multibytecodec) \
|
||||||
|
_(_multiprocessing) \
|
||||||
|
_(nis) \
|
||||||
|
_(_opcode) \
|
||||||
|
_(ossaudiodev) \
|
||||||
|
_(parser) \
|
||||||
|
_(_pickle) \
|
||||||
|
_(_posixsubprocess) \
|
||||||
|
_(pyexpat) \
|
||||||
|
_(_queue) \
|
||||||
|
_(_random) \
|
||||||
|
_(readline) \
|
||||||
|
_(resource) \
|
||||||
|
_(select) \
|
||||||
|
_(_sha1) \
|
||||||
|
_(_sha256) \
|
||||||
|
_(_sha3) \
|
||||||
|
_(_sha512) \
|
||||||
|
_(_socket) \
|
||||||
|
_(spwd) \
|
||||||
|
_(_ssl) \
|
||||||
|
_(_struct) \
|
||||||
|
_(syslog) \
|
||||||
|
_(termios) \
|
||||||
|
_(_testbuffer) \
|
||||||
|
_(_testcapi) \
|
||||||
|
_(_testimportmultiple) \
|
||||||
|
_(_testmultiphase) \
|
||||||
|
_(unicodedata) \
|
||||||
|
_(xxlimited) \
|
||||||
|
_(_xxtestfuzz) \
|
||||||
|
_(zlib)
|
||||||
|
|
||||||
|
#define DECLARE_LIBRARY_INIT(name) extern "C" PyObject* PyInit_##name(void);
|
||||||
|
FOREACH_LIBRARY(DECLARE_LIBRARY_INIT)
|
||||||
|
#undef DECLARE_LIBRARY_INIT
|
||||||
|
|
||||||
|
extern "C" struct _frozen _PyImport_FrozenModules[];
|
||||||
|
|
||||||
|
#define STD_LIBARY_PARMS(name) , #name, PyInit_##name
|
||||||
|
REGISTER_TORCH_DEPLOY_BUILTIN(
|
||||||
|
frozenpython,
|
||||||
|
_PyImport_FrozenModules FOREACH_LIBRARY(STD_LIBARY_PARMS));
|
||||||
|
#undef STD_LIBARY_PARMS
|
||||||
29
torch/csrc/jit/mobile/code.h
Normal file
29
torch/csrc/jit/mobile/code.h
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ATen/core/ivalue.h>
|
||||||
|
#include <ATen/core/operator_name.h>
|
||||||
|
#include <torch/csrc/jit/runtime/instruction.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
namespace mobile {
|
||||||
|
|
||||||
|
using Stack = std::vector<c10::IValue>;
|
||||||
|
using DebugHandle = int64_t;
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
||||||
|
struct Code {
|
||||||
|
std::vector<Instruction> instructions_;
|
||||||
|
std::vector<DebugHandle> debug_handles_;
|
||||||
|
std::vector<c10::OperatorName> op_names_;
|
||||||
|
std::vector<std::function<void(Stack&)>> operators_;
|
||||||
|
std::vector<c10::IValue> constants_;
|
||||||
|
std::vector<c10::TypePtr> types_;
|
||||||
|
size_t register_size_; // Aggregated output size.
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mobile
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
53
torch/csrc/jit/mobile/frame.h
Normal file
53
torch/csrc/jit/mobile/frame.h
Normal file
|
|
@ -0,0 +1,53 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
|
||||||
|
#include <c10/util/Optional.h>
|
||||||
|
#include <torch/csrc/jit/mobile/code.h>
|
||||||
|
|
||||||
|
namespace torch {
|
||||||
|
namespace jit {
|
||||||
|
namespace mobile {
|
||||||
|
|
||||||
|
class Frame {
|
||||||
|
public:
|
||||||
|
explicit Frame(const Code& code) : code_(code) {}
|
||||||
|
const Code& getCode() const {
|
||||||
|
return code_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void step() {
|
||||||
|
pc_++;
|
||||||
|
}
|
||||||
|
|
||||||
|
void jump(size_t n) {
|
||||||
|
pc_ += n;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getPC() const {
|
||||||
|
return pc_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Instruction& getInstruction() const {
|
||||||
|
return code_.instructions_.at(pc_);
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::optional<int64_t> getDebugHandle() const {
|
||||||
|
return getDebugHandle(pc_);
|
||||||
|
}
|
||||||
|
|
||||||
|
c10::optional<int64_t> getDebugHandle(size_t pc) const {
|
||||||
|
if (pc >= code_.debug_handles_.size()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return code_.debug_handles_[pc];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const Code& code_;
|
||||||
|
size_t pc_{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mobile
|
||||||
|
} // namespace jit
|
||||||
|
} // namespace torch
|
||||||
|
|
@ -167,7 +167,7 @@ bool Function::run(Stack& stack) const {
|
||||||
schema->checkAndNormalizeInputs(
|
schema->checkAndNormalizeInputs(
|
||||||
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
|
stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
|
||||||
}
|
}
|
||||||
InterpreterState interp_state(code_);
|
InterpreterState interp_state(*code_);
|
||||||
return interp_state.run(stack);
|
return interp_state.run(stack);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -181,8 +181,7 @@ const std::shared_ptr<Code> Function::get_code() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t Function::getExceptionDebugHandle() const {
|
int64_t Function::getExceptionDebugHandle() const {
|
||||||
size_t pc = getInterpretersExceptionPC();
|
return getInterpretersExceptionDebugHandle();
|
||||||
return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
|
|
|
||||||
|
|
@ -8,23 +8,22 @@
|
||||||
#include <torch/csrc/jit/runtime/vararg_functions.h>
|
#include <torch/csrc/jit/runtime/vararg_functions.h>
|
||||||
|
|
||||||
#include <ATen/record_function.h>
|
#include <ATen/record_function.h>
|
||||||
|
#include <c10/util/Exception.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
#include <torch/csrc/jit/mobile/observer.h>
|
|
||||||
|
|
||||||
#include <torch/csrc/jit/backends/backend_exception.h>
|
#include <torch/csrc/jit/backends/backend_exception.h>
|
||||||
|
#include <torch/csrc/jit/mobile/observer.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
char const* toString(OpCode op);
|
char const* toString(OpCode op);
|
||||||
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
std::ostream& operator<<(std::ostream& out, Instruction inst);
|
||||||
namespace mobile {
|
namespace mobile {
|
||||||
InterpreterState::InterpreterState(std::shared_ptr<Code> code)
|
InterpreterState::InterpreterState(const Code& code) {
|
||||||
: code_(std::move(code)) {
|
enterFrame(code);
|
||||||
registers_.resize(code_->register_size_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
static thread_local int64_t exception_pc_{-1};
|
static thread_local DebugHandle exception_debug_handle_{-1};
|
||||||
void createObject(Stack& stack, const at::ClassTypePtr& type) {
|
void createObject(Stack& stack, const at::ClassTypePtr& type) {
|
||||||
auto userObj = c10::ivalue::Object::create(
|
auto userObj = c10::ivalue::Object::create(
|
||||||
c10::StrongTypePtr(type->compilation_unit(), type),
|
c10::StrongTypePtr(type->compilation_unit(), type),
|
||||||
|
|
@ -46,21 +45,42 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {
|
||||||
|
|
||||||
using namespace at;
|
using namespace at;
|
||||||
|
|
||||||
int64_t getInterpretersExceptionPC() {
|
int64_t getInterpretersExceptionDebugHandle() {
|
||||||
return exception_pc_;
|
return exception_debug_handle_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void InterpreterState::enterFrame(const Code& code) {
|
||||||
|
frames_.emplace_back(code);
|
||||||
|
registers_.resize(registers_.size() + code.register_size_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InterpreterState::leaveFrame() {
|
||||||
|
registers_.resize(
|
||||||
|
registers_.size() - frames_.back().getCode().register_size_);
|
||||||
|
frames_.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InterpreterState::saveExceptionDebugHandle() {
|
||||||
|
const auto& frame = frames_.back();
|
||||||
|
if (auto handle = frame.getDebugHandle()) {
|
||||||
|
exception_debug_handle_ = *handle;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool InterpreterState::run(Stack& stack) {
|
bool InterpreterState::run(Stack& stack) {
|
||||||
size_t pc = 0;
|
|
||||||
while (true) {
|
while (true) {
|
||||||
try {
|
try {
|
||||||
Instruction inst = code_->instructions_.at(pc);
|
auto& frame = frames_.back();
|
||||||
|
const auto& code = frame.getCode();
|
||||||
|
const auto pc = frame.getPC();
|
||||||
|
auto inst = frame.getInstruction();
|
||||||
// If no valid debug handle found then just log pc.
|
// If no valid debug handle found then just log pc.
|
||||||
// This is possible when we did not save debug handles
|
// This is possible when we did not save debug handles
|
||||||
DebugHandle debug_handle = pc >= code_->debug_handles_.size()
|
|
||||||
? pc
|
DebugHandle debug_handle = pc;
|
||||||
: code_->debug_handles_.at(pc);
|
if (auto handle = frame.getDebugHandle()) {
|
||||||
|
debug_handle = *handle;
|
||||||
|
}
|
||||||
|
|
||||||
// std::cout << "RUNNING " << pc << " "
|
// std::cout << "RUNNING " << pc << " "
|
||||||
// << code_->instructions_with_handles_[pc].instruction;
|
// << code_->instructions_with_handles_[pc].instruction;
|
||||||
|
|
@ -93,63 +113,63 @@ bool InterpreterState::run(Stack& stack) {
|
||||||
}
|
}
|
||||||
|
|
||||||
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
||||||
code_->op_names_[inst.X].name, debug_handle, stack);
|
code.op_names_[inst.X].name, debug_handle, stack);
|
||||||
code_->operators_[inst.X](stack);
|
code.operators_[inst.X](stack);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case OPN: {
|
case OPN: {
|
||||||
stack.push_back(inst.N);
|
stack.push_back(inst.N);
|
||||||
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
||||||
code_->op_names_[inst.X].name, debug_handle, stack);
|
code.op_names_[inst.X].name, debug_handle, stack);
|
||||||
code_->operators_[inst.X](stack);
|
code.operators_[inst.X](stack);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case INTERFACE_CALL: {
|
case INTERFACE_CALL: {
|
||||||
torch::jit::Function& method =
|
torch::jit::Function& method =
|
||||||
peek(stack, 0, inst.N)
|
peek(stack, 0, inst.N)
|
||||||
.toObject()
|
.toObject()
|
||||||
->type()
|
->type()
|
||||||
->getMethod(code_->constants_[inst.X].toStringRef());
|
->getMethod(code.constants_[inst.X].toStringRef());
|
||||||
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
|
||||||
method.name(), debug_handle, stack);
|
method.name(), debug_handle, stack);
|
||||||
method.run(stack);
|
method.run(stack);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case LOAD:
|
case LOAD:
|
||||||
stack.emplace_back(reg(inst.X));
|
stack.emplace_back(reg(inst.X));
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case MOVE:
|
case MOVE:
|
||||||
stack.emplace_back(std::move(reg(inst.X)));
|
stack.emplace_back(std::move(reg(inst.X)));
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case STORE:
|
case STORE:
|
||||||
reg(inst.X) = pop(stack);
|
reg(inst.X) = pop(stack);
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case STOREN:
|
case STOREN:
|
||||||
for (size_t i = inst.N; i > 0; --i) {
|
for (size_t i = inst.N; i > 0; --i) {
|
||||||
reg(inst.X + i - 1) = pop(stack);
|
reg(inst.X + i - 1) = pop(stack);
|
||||||
}
|
}
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case DROP:
|
case DROP:
|
||||||
pop(stack);
|
pop(stack);
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case DROPR:
|
case DROPR:
|
||||||
reg(inst.X) = IValue();
|
reg(inst.X) = IValue();
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case LOADC:
|
case LOADC:
|
||||||
stack.emplace_back(code_->constants_[inst.X]);
|
stack.emplace_back(code.constants_[inst.X]);
|
||||||
++pc;
|
frame.step();
|
||||||
break;
|
break;
|
||||||
case GET_ATTR: {
|
case GET_ATTR: {
|
||||||
auto userObj = pop(stack).toObject();
|
auto userObj = pop(stack).toObject();
|
||||||
auto value = userObj->getSlot(inst.X);
|
auto value = userObj->getSlot(inst.X);
|
||||||
push(stack, std::move(value));
|
push(stack, std::move(value));
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case SET_ATTR: {
|
case SET_ATTR: {
|
||||||
auto v = pop(stack);
|
auto v = pop(stack);
|
||||||
|
|
@ -163,72 +183,76 @@ bool InterpreterState::run(Stack& stack) {
|
||||||
userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
|
userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
|
||||||
}
|
}
|
||||||
userObj->setSlot(inst.X, std::move(v));
|
userObj->setSlot(inst.X, std::move(v));
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case JF:
|
case JF:
|
||||||
pc += (pop(stack).toBool()) ? 1 : inst.X;
|
frame.jump(pop(stack).toBool() ? 1 : inst.X);
|
||||||
break;
|
break;
|
||||||
case JMP:
|
case JMP:
|
||||||
pc += inst.X;
|
frame.jump(inst.X);
|
||||||
break;
|
break;
|
||||||
case LOOP: {
|
case LOOP: {
|
||||||
// stack: iteration_count, max_iter, cond, loop_carried_deps...
|
// stack: iteration_count, max_iter, cond, loop_carried_deps...
|
||||||
auto frame = stack.end() - (inst.N + 1);
|
auto sframe = stack.end() - (inst.N + 1);
|
||||||
int64_t trip_count = frame[0].toInt();
|
int64_t trip_count = sframe[0].toInt();
|
||||||
int64_t max_trip_count = frame[1].toInt();
|
int64_t max_trip_count = sframe[1].toInt();
|
||||||
bool cond = frame[2].toBool();
|
bool cond = sframe[2].toBool();
|
||||||
if (trip_count < max_trip_count && cond) {
|
if (trip_count < max_trip_count && cond) {
|
||||||
frame[2] = trip_count;
|
sframe[2] = trip_count;
|
||||||
frame[0] = trip_count + 1;
|
sframe[0] = trip_count + 1;
|
||||||
++pc;
|
frame.step();
|
||||||
} else {
|
} else {
|
||||||
size_t n_loop_carried = inst.N - 2;
|
size_t n_loop_carried = inst.N - 2;
|
||||||
for (const auto i : c10::irange(n_loop_carried)) {
|
for (const auto i : c10::irange(n_loop_carried)) {
|
||||||
frame[i] = std::move(frame[i + 3]);
|
sframe[i] = std::move(sframe[i + 3]);
|
||||||
}
|
}
|
||||||
drop(stack, 3); // iteration_count, max_iter, cond
|
drop(stack, 3); // iteration_count, max_iter, cond
|
||||||
pc += inst.X;
|
frame.jump(inst.X);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case RET:
|
case RET:
|
||||||
|
leaveFrame();
|
||||||
|
if (frames_.size() > 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
case LIST_CONSTRUCT: {
|
case LIST_CONSTRUCT: {
|
||||||
const auto& type = code_->types_[inst.X]->expectRef<at::ListType>();
|
const auto& type = code.types_[inst.X]->expectRef<at::ListType>();
|
||||||
listConstruct(stack, type, inst.N);
|
listConstruct(stack, type, inst.N);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case LIST_UNPACK: {
|
case LIST_UNPACK: {
|
||||||
listUnpack(stack, inst.X);
|
listUnpack(stack, inst.X);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case TUPLE_CONSTRUCT: {
|
case TUPLE_CONSTRUCT: {
|
||||||
tupleConstruct(stack, inst.X);
|
tupleConstruct(stack, inst.X);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case TUPLE_SLICE: {
|
case TUPLE_SLICE: {
|
||||||
tupleSlice(stack, inst.X, inst.X + inst.N);
|
tupleSlice(stack, inst.X, inst.X + inst.N);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case DICT_CONSTRUCT: {
|
case DICT_CONSTRUCT: {
|
||||||
const auto& type = code_->types_[inst.X]->expectRef<at::DictType>();
|
const auto& type = code.types_[inst.X]->expectRef<at::DictType>();
|
||||||
dictConstruct(stack, type, inst.N);
|
dictConstruct(stack, type, inst.N);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case NAMED_TUPLE_CONSTRUCT: {
|
case NAMED_TUPLE_CONSTRUCT: {
|
||||||
namedTupleConstruct(
|
namedTupleConstruct(
|
||||||
stack, code_->types_[inst.X]->expect<at::TupleType>(), inst.N);
|
stack, code.types_[inst.X]->expect<at::TupleType>(), inst.N);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case CREATE_OBJECT: {
|
case CREATE_OBJECT: {
|
||||||
auto type = code_->types_[inst.X]->expect<c10::ClassType>();
|
auto type = code.types_[inst.X]->expect<c10::ClassType>();
|
||||||
createObject(stack, type);
|
createObject(stack, type);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case ISINSTANCE: {
|
case ISINSTANCE: {
|
||||||
at::ArrayRef<TypePtr> types(
|
at::ArrayRef<TypePtr> types(
|
||||||
&(code_->types_[inst.X]), &(code_->types_[inst.X + inst.N]));
|
&(code.types_[inst.X]), &(code.types_[inst.X + inst.N]));
|
||||||
isinstance(stack, types);
|
isinstance(stack, types);
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
case WARN: {
|
case WARN: {
|
||||||
drop(stack, 1);
|
drop(stack, 1);
|
||||||
|
|
@ -240,7 +264,7 @@ bool InterpreterState::run(Stack& stack) {
|
||||||
const auto& sref = stack.back().toStringRef();
|
const auto& sref = stack.back().toStringRef();
|
||||||
TORCH_WARN(sref);
|
TORCH_WARN(sref);
|
||||||
stack.pop_back();
|
stack.pop_back();
|
||||||
++pc;
|
frame.step();
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
AT_ERROR(toString(inst.op), " is invalid.");
|
AT_ERROR(toString(inst.op), " is invalid.");
|
||||||
|
|
@ -251,15 +275,15 @@ bool InterpreterState::run(Stack& stack) {
|
||||||
}
|
}
|
||||||
// This exception must be caught first as it derived from c10::Error
|
// This exception must be caught first as it derived from c10::Error
|
||||||
} catch (c10::BackendRuntimeException& e) {
|
} catch (c10::BackendRuntimeException& e) {
|
||||||
exception_pc_ = pc;
|
saveExceptionDebugHandle();
|
||||||
TORCH_RETHROW(e);
|
TORCH_RETHROW(e);
|
||||||
} catch (c10::Error& error) {
|
} catch (c10::Error& error) {
|
||||||
// Reason for catching and rethrowing the error is so that we can
|
// Reason for catching and rethrowing the error is so that we can
|
||||||
// set the exception pc that is queried later
|
// set the exception pc that is queried later
|
||||||
exception_pc_ = pc;
|
saveExceptionDebugHandle();
|
||||||
TORCH_RETHROW(error);
|
TORCH_RETHROW(error);
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
exception_pc_ = pc;
|
saveExceptionDebugHandle();
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
// for (auto val : stack) {
|
// for (auto val : stack) {
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,26 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <ATen/core/ivalue.h>
|
|
||||||
#include <ATen/core/operator_name.h>
|
|
||||||
#include <torch/csrc/jit/runtime/instruction.h>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include <torch/csrc/jit/mobile/code.h>
|
||||||
|
#include <torch/csrc/jit/mobile/frame.h>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
namespace mobile {
|
namespace mobile {
|
||||||
using Stack = std::vector<c10::IValue>;
|
|
||||||
using DebugHandle = int64_t;
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
||||||
struct Code {
|
|
||||||
// TODO: Combine instructions and debug handles vector
|
|
||||||
// into std::vector<<std::pair<Instruction, DebugHandle>>
|
|
||||||
std::vector<Instruction> instructions_;
|
|
||||||
std::vector<DebugHandle> debug_handles_;
|
|
||||||
std::vector<c10::OperatorName> op_names_;
|
|
||||||
std::vector<std::function<void(Stack&)>> operators_;
|
|
||||||
std::vector<c10::IValue> constants_;
|
|
||||||
std::vector<c10::TypePtr> types_;
|
|
||||||
size_t register_size_; // Aggregated output size.
|
|
||||||
};
|
|
||||||
|
|
||||||
struct InterpreterState {
|
struct InterpreterState {
|
||||||
TORCH_API explicit InterpreterState(std::shared_ptr<Code> code);
|
TORCH_API explicit InterpreterState(const Code& code);
|
||||||
TORCH_API bool run(Stack& stack);
|
TORCH_API bool run(Stack& stack);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Code> code_;
|
void enterFrame(const Code&);
|
||||||
|
void leaveFrame();
|
||||||
|
void saveExceptionDebugHandle();
|
||||||
|
|
||||||
c10::IValue& reg(size_t reg);
|
c10::IValue& reg(size_t reg);
|
||||||
std::vector<c10::IValue> registers_;
|
std::vector<c10::IValue> registers_;
|
||||||
|
std::vector<Frame> frames_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Interpreter executes instruction in a loop one by one
|
// Interpreter executes instruction in a loop one by one
|
||||||
|
|
@ -39,7 +30,7 @@ struct InterpreterState {
|
||||||
// Note that this is set only when exception occurs.
|
// Note that this is set only when exception occurs.
|
||||||
// since this is a thread local variable and setting it for
|
// since this is a thread local variable and setting it for
|
||||||
// every instruction will add overhead of thread local variable access.
|
// every instruction will add overhead of thread local variable access.
|
||||||
int64_t getInterpretersExceptionPC();
|
DebugHandle getInterpretersExceptionDebugHandle();
|
||||||
} // namespace mobile
|
} // namespace mobile
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ void _not(Stack& stack) {
|
||||||
void boolTensor(Stack& stack) {
|
void boolTensor(Stack& stack) {
|
||||||
at::Tensor a;
|
at::Tensor a;
|
||||||
pop(stack, a);
|
pop(stack, a);
|
||||||
push(stack, a.is_nonzero());
|
push(stack, at::native::is_nonzero(a));
|
||||||
}
|
}
|
||||||
|
|
||||||
void toList(Stack& stack) {
|
void toList(Stack& stack) {
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,16 @@ namespace jit {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
std::vector<std::string> splitName(const std::string& name) {
|
||||||
|
std::vector<std::string> result;
|
||||||
|
std::string sub_name;
|
||||||
|
std::istringstream name_stream(name);
|
||||||
|
while (std::getline(name_stream, sub_name, '.')) {
|
||||||
|
result.push_back(std::move(sub_name));
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
class AttributePropagator {
|
class AttributePropagator {
|
||||||
public:
|
public:
|
||||||
AttributePropagator(
|
AttributePropagator(
|
||||||
|
|
@ -27,27 +37,27 @@ class AttributePropagator {
|
||||||
: module_(module),
|
: module_(module),
|
||||||
freezeInterfaces_(freezeInterfaces),
|
freezeInterfaces_(freezeInterfaces),
|
||||||
preserveParameters_(preserveParameters) {
|
preserveParameters_(preserveParameters) {
|
||||||
// Currently only top level attributes and functions can be preserved
|
|
||||||
// explicitly.
|
|
||||||
auto checkName = [this](std::string& name) {
|
auto checkName = [this](std::string& name) {
|
||||||
if (module_.hasattr(name)) {
|
const auto resolved_name = resolveName(name);
|
||||||
auto attr = module_.attr(name);
|
|
||||||
|
|
||||||
|
if (resolved_name) {
|
||||||
|
const auto& parent_module = resolved_name->first;
|
||||||
|
const auto& attr_name = resolved_name->second;
|
||||||
|
if (parent_module.hasattr(attr_name)) {
|
||||||
|
auto value = parent_module.attr(attr_name);
|
||||||
// Freezing client wants to presever this submodule. When cleaning
|
// Freezing client wants to presever this submodule. When cleaning
|
||||||
// the frozen module, make sure it will be preserved entirely.
|
// the frozen module, make sure it will be preserved entirely.
|
||||||
if (attr.isModule()) {
|
if (value.isModule()) {
|
||||||
preservedSubModule_.insert(attr.toModule()._ivalue());
|
preservedSubModule_.insert(value.toModule()._ivalue());
|
||||||
|
}
|
||||||
|
insertMutableAttr(attr_name, value, parent_module._ivalue());
|
||||||
|
} else {
|
||||||
|
auto fn = parent_module.get_method(attr_name);
|
||||||
|
preservedMethods_.insert(&fn.function());
|
||||||
}
|
}
|
||||||
insertMutableAttr(name, attr, module_._ivalue());
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& fn : module_.type()->methods()) {
|
|
||||||
if (fn->name() == name) {
|
|
||||||
preservedMethods_.insert(fn);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -119,6 +129,57 @@ class AttributePropagator {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
using ResolvedName = std::pair<Module, std::string>;
|
||||||
|
|
||||||
|
// Try to resolve qualified names (submodule1.submodule2.foo). If
|
||||||
|
// the qualified name exists in the root module, return the unqualified
|
||||||
|
// attribute/function name and the parent module. Else, return nullopt.
|
||||||
|
// Examples:
|
||||||
|
// submodule1.submodule2.foo -> {submodule2, "foo"}
|
||||||
|
// submodule1.non_existent_module.foo -> nullopt
|
||||||
|
c10::optional<ResolvedName> resolveName(const std::string& name) {
|
||||||
|
auto sub_names = splitName(name);
|
||||||
|
if (sub_names.empty()) {
|
||||||
|
return c10::nullopt;
|
||||||
|
}
|
||||||
|
auto& attr_name = sub_names.back();
|
||||||
|
auto cur_module = module_;
|
||||||
|
std::vector<ResolvedName> attr_infos;
|
||||||
|
attr_infos.reserve(sub_names.size() - 1);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < sub_names.size() - 1; ++i) {
|
||||||
|
bool found = false;
|
||||||
|
const auto& sub_name = sub_names[i];
|
||||||
|
for (const auto& child_module : cur_module.named_children()) {
|
||||||
|
if (child_module.name == sub_name) {
|
||||||
|
attr_infos.emplace_back(cur_module._ivalue(), child_module.name);
|
||||||
|
cur_module = child_module.value;
|
||||||
|
found = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!found) {
|
||||||
|
return c10::nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) {
|
||||||
|
// We don't want to mark these modules as mutable yet; that could
|
||||||
|
// interfere with the inlining procedure. Instead, we'll record
|
||||||
|
// the fact that the user wants to preserve them. They will be
|
||||||
|
// processed during clean-up preparation (recordReferenceAttrs)
|
||||||
|
for (auto& attr_info : attr_infos) {
|
||||||
|
const auto& parent_module = attr_info.first;
|
||||||
|
auto& sub_name = attr_info.second;
|
||||||
|
userPreservedAttrs_[parent_module._ivalue()].insert(
|
||||||
|
std::move(sub_name));
|
||||||
|
}
|
||||||
|
return std::make_pair(std::move(cur_module), std::move(attr_name));
|
||||||
|
}
|
||||||
|
|
||||||
|
return c10::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
// findConstantAttr function locates the sub Module where attributes are
|
// findConstantAttr function locates the sub Module where attributes are
|
||||||
// defined. The algorithm chases getAttr chains to locate the submodules.
|
// defined. The algorithm chases getAttr chains to locate the submodules.
|
||||||
// For example:
|
// For example:
|
||||||
|
|
@ -638,6 +699,16 @@ class AttributePropagator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// We have to process the attributes that the user wants to preserve
|
||||||
|
// separately since it's possible that the user-preserved module is
|
||||||
|
// never referenced in the graph.
|
||||||
|
for (const auto& attr_info : userPreservedAttrs_) {
|
||||||
|
const auto& parent_module = attr_info.first;
|
||||||
|
for (const auto& attr_name : attr_info.second) {
|
||||||
|
const auto value = parent_module->getAttr(attr_name);
|
||||||
|
insertMutableAttr(attr_name, value, parent_module);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function recursively iterates over submodules to identify
|
// This function recursively iterates over submodules to identify
|
||||||
|
|
@ -710,7 +781,7 @@ class AttributePropagator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto& fn : type->methods()) {
|
for (auto& fn : type->methods()) {
|
||||||
if (preservedMethods_.count(fn) && *type == *module_.type()) {
|
if (preservedMethods_.count(fn)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
funcsToRemove.push_back(fn);
|
funcsToRemove.push_back(fn);
|
||||||
|
|
@ -774,6 +845,11 @@ class AttributePropagator {
|
||||||
c10::intrusive_ptr<at::ivalue::Object>>
|
c10::intrusive_ptr<at::ivalue::Object>>
|
||||||
object_memo_;
|
object_memo_;
|
||||||
|
|
||||||
|
// Contains names of attributes that the user wants to preserve with
|
||||||
|
// their owning modules.
|
||||||
|
std::unordered_map<ModulePtr, std::unordered_set<std::string>>
|
||||||
|
userPreservedAttrs_;
|
||||||
|
|
||||||
}; // class AttributePropagator
|
}; // class AttributePropagator
|
||||||
|
|
||||||
void checkModuleDoesNotReturnSelf(const Module& module) {
|
void checkModuleDoesNotReturnSelf(const Module& module) {
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,7 @@
|
||||||
#include <torch/csrc/jit/runtime/autodiff.h>
|
#include <torch/csrc/jit/runtime/autodiff.h>
|
||||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||||
#include <torch/csrc/jit/runtime/jit_exception.h>
|
#include <torch/csrc/jit/runtime/jit_exception.h>
|
||||||
|
#include <torch/csrc/jit/runtime/jit_trace.h>
|
||||||
#include <torch/csrc/jit/runtime/operator.h>
|
#include <torch/csrc/jit/runtime/operator.h>
|
||||||
#include <torch/csrc/jit/runtime/print_handler.h>
|
#include <torch/csrc/jit/runtime/print_handler.h>
|
||||||
#include <torch/csrc/jit/runtime/static/init.h>
|
#include <torch/csrc/jit/runtime/static/init.h>
|
||||||
|
|
@ -520,6 +521,22 @@ void initJITBindings(PyObject* module) {
|
||||||
},
|
},
|
||||||
py::doc(
|
py::doc(
|
||||||
"Interpret a JIT graph with given inputs without running any optimization passes on it"))
|
"Interpret a JIT graph with given inputs without running any optimization passes on it"))
|
||||||
|
.def(
|
||||||
|
"_jit_trace_graph",
|
||||||
|
[](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
|
||||||
|
Stack stack;
|
||||||
|
stack.reserve(inputs.size()); // captures?
|
||||||
|
for (auto& obj : inputs) {
|
||||||
|
stack.push_back(toTypeInferredIValue(obj));
|
||||||
|
}
|
||||||
|
auto g_inputs = graph->inputs();
|
||||||
|
for (const auto i : c10::irange(inputs.size())) {
|
||||||
|
if (stack[i].isTensor()) {
|
||||||
|
g_inputs[i]->setType(stack[i].type());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return TraceGraph(graph, stack);
|
||||||
|
})
|
||||||
.def("_jit_pass_remove_expands", RemoveExpands)
|
.def("_jit_pass_remove_expands", RemoveExpands)
|
||||||
.def("_jit_pass_erase_number_types", EraseNumberTypes)
|
.def("_jit_pass_erase_number_types", EraseNumberTypes)
|
||||||
.def("_jit_pass_inline_fork_wait", InlineForkWait)
|
.def("_jit_pass_inline_fork_wait", InlineForkWait)
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user