Update on "[Pytorch][Bootcamp] Add fix and testing for vectorized Adadelta optimizer to handle complex numbers"
Made changes in the step function of the vectorized Adadelta optimizer to handle complex numbers as two real numbers, as per #65711 on GitHub. Differential Revision: [D31631870](https://our.internmc.facebook.com/intern/diff/D31631870/) [ghstack-poisoned]
This commit is contained in:
commit f2aa06c17a
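For reference, the standard way PyTorch's vectorized (multi-tensor) optimizers support complex parameters is to view each complex tensor as a real tensor with a trailing dimension of two and run the ordinary real-valued update on that view. A minimal sketch of that pattern applied to the Adadelta update; the function and variable names are illustrative, not the exact implementation:

import torch

def adadelta_step(params, grads, square_avgs, acc_deltas, lr=1.0, rho=0.9, eps=1e-6):
    # Treat each complex tensor as two real numbers: view it as a real
    # tensor of shape (..., 2) so the foreach math only ever sees real dtypes.
    # In-place updates on the views write through to the complex originals.
    def as_real(tensors):
        return [torch.view_as_real(t) if torch.is_complex(t) else t for t in tensors]

    params, grads = as_real(params), as_real(grads)
    square_avgs, acc_deltas = as_real(square_avgs), as_real(acc_deltas)

    # square_avg = rho * square_avg + (1 - rho) * grad^2
    torch._foreach_mul_(square_avgs, rho)
    torch._foreach_addcmul_(square_avgs, grads, grads, value=1 - rho)
    # delta = sqrt(acc_delta + eps) / sqrt(square_avg + eps) * grad
    std = torch._foreach_add(square_avgs, eps)
    torch._foreach_sqrt_(std)
    deltas = torch._foreach_add(acc_deltas, eps)
    torch._foreach_sqrt_(deltas)
    torch._foreach_div_(deltas, std)
    torch._foreach_mul_(deltas, grads)
    # param -= lr * delta; acc_delta = rho * acc_delta + (1 - rho) * delta^2
    torch._foreach_add_(params, deltas, alpha=-lr)
    torch._foreach_mul_(acc_deltas, rho)
    torch._foreach_addcmul_(acc_deltas, deltas, deltas, value=1 - rho)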
@@ -10,7 +10,6 @@ CONFIG_TREE_DATA = [
             ]),
         ]),
             # TODO: bring back libtorch test
-            ("7", [X("3.6")]),
         ]),
         ("cuda", [
             ("10.2", [

.circleci/config.yml (generated): 44 lines changed
@@ -6582,31 +6582,6 @@ workflows:
           name: pytorch_cpp_doc_push
           requires:
             - pytorch_cpp_doc_build
-      - pytorch_linux_build:
-          name: pytorch_linux_xenial_py3_6_gcc7_build
-          requires:
-            - "docker-pytorch-linux-xenial-py3.6-gcc7"
-          filters:
-            branches:
-              only:
-                - master
-                - /ci-all\/.*/
-                - /release\/.*/
-          build_environment: "pytorch-linux-xenial-py3.6-gcc7-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
-      - pytorch_linux_test:
-          name: pytorch_linux_xenial_py3_6_gcc7_test
-          requires:
-            - pytorch_linux_xenial_py3_6_gcc7_build
-          filters:
-            branches:
-              only:
-                - master
-                - /ci-all\/.*/
-                - /release\/.*/
-          build_environment: "pytorch-linux-xenial-py3.6-gcc7-test"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
-          resource_class: large
       - pytorch_linux_build:
           name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
           requires:

@@ -8334,9 +8309,6 @@ workflows:
              only: /.*/
            tags:
              only: /v[0-9]+(\.[0-9]+)*-rc[0-9]+/
-      - docker_build_job:
-          name: "docker-pytorch-linux-xenial-py3.6-gcc7"
-          image_name: "pytorch-linux-xenial-py3.6-gcc7"
    when: << pipeline.parameters.run_build >>
  master_build:
    jobs:

@@ -8352,19 +8324,6 @@ workflows:
      - pytorch_cpp_doc_build:
          requires:
            - pytorch_linux_xenial_py3_6_gcc5_4_build
-      - pytorch_linux_build:
-          name: pytorch_linux_xenial_py3_6_gcc7_build
-          requires:
-            - "docker-pytorch-linux-xenial-py3.6-gcc7"
-          build_environment: "pytorch-linux-xenial-py3.6-gcc7-build"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
-      - pytorch_linux_test:
-          name: pytorch_linux_xenial_py3_6_gcc7_test
-          requires:
-            - pytorch_linux_xenial_py3_6_gcc7_build
-          build_environment: "pytorch-linux-xenial-py3.6-gcc7-test"
-          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7"
-          resource_class: large
      - pytorch_linux_build:
          name: pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build
          requires:

@@ -8474,9 +8433,6 @@ workflows:
      - docker_build_job:
          name: "docker-pytorch-linux-xenial-py3.6-gcc5.4"
          image_name: "pytorch-linux-xenial-py3.6-gcc5.4"
-      - docker_build_job:
-          name: "docker-pytorch-linux-xenial-py3.6-gcc7"
-          image_name: "pytorch-linux-xenial-py3.6-gcc7"
    when: << pipeline.parameters.run_master_build >>
  ecr_gc:
    triggers:
.github/generated-ciflow-ruleset.json (generated, vendored): 4 lines changed

@@ -17,6 +17,7 @@
       "linux-xenial-py3.6-clang7-asan",
       "linux-xenial-py3.6-clang7-onnx",
       "linux-xenial-py3.6-gcc5.4",
+      "linux-xenial-py3.6-gcc7",
       "linux-xenial-py3.6-gcc7-bazel-test",
       "parallelnative-linux-xenial-py3.6-gcc5.4",
       "periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7",

@@ -36,6 +37,7 @@
       "linux-xenial-py3.6-clang7-asan",
       "linux-xenial-py3.6-clang7-onnx",
       "linux-xenial-py3.6-gcc5.4",
+      "linux-xenial-py3.6-gcc7",
       "linux-xenial-py3.6-gcc7-bazel-test",
       "parallelnative-linux-xenial-py3.6-gcc5.4",
       "win-vs2019-cpu-py3"

@@ -62,6 +64,7 @@
       "linux-xenial-py3.6-clang7-asan",
       "linux-xenial-py3.6-clang7-onnx",
       "linux-xenial-py3.6-gcc5.4",
+      "linux-xenial-py3.6-gcc7",
       "linux-xenial-py3.6-gcc7-bazel-test",
       "win-vs2019-cpu-py3",
       "win-vs2019-cuda11.3-py3"

@@ -87,6 +90,7 @@
       "linux-xenial-py3.6-clang7-asan",
       "linux-xenial-py3.6-clang7-onnx",
       "linux-xenial-py3.6-gcc5.4",
+      "linux-xenial-py3.6-gcc7",
       "linux-xenial-py3.6-gcc7-bazel-test",
       "parallelnative-linux-xenial-py3.6-gcc5.4",
       "periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7",
.github/scripts/generate_ci_workflows.py (vendored): 11 lines changed

@@ -274,6 +274,17 @@ LINUX_WORKFLOWS = [
             labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}
         ),
     ),
+    CIWorkflow(
+        arch="linux",
+        build_environment="linux-xenial-py3.6-gcc7",
+        docker_image_base=f"{DOCKER_REGISTRY}/pytorch/pytorch-linux-xenial-py3.6-gcc7",
+        test_runner_type=LINUX_CPU_TEST_RUNNER,
+        num_test_shards=2,
+        ciflow_config=CIFlowConfig(
+            run_on_canary=True,
+            labels={LABEL_CIFLOW_DEFAULT, LABEL_CIFLOW_LINUX, LABEL_CIFLOW_CPU}
+        ),
+    ),
     # ParallelTBB does not have a maintainer and is currently flaky
     # CIWorkflow(
     #     arch="linux",
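Each CIWorkflow entry is rendered through the Jinja template named in the generated file's header. A rough sketch of that generation flow, assuming jinja2 and a heavily reduced field set (the real script defines many more fields and also writes the ciflow ruleset JSON):

import dataclasses
from pathlib import Path

import jinja2


@dataclasses.dataclass
class CIWorkflow:
    arch: str
    build_environment: str
    docker_image_base: str
    num_test_shards: int = 1


def generate(workflows):
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(".github/templates"))
    template = env.get_template("linux_ci_workflow.yml.j2")
    for wf in workflows:
        out = Path(".github/workflows") / f"generated-{wf.build_environment}.yml"
        # Every generated file carries the "@generated DO NOT EDIT" banner.
        out.write_text(template.render(**dataclasses.asdict(wf)))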
.github/workflows/generated-linux-xenial-py3.6-gcc7.yml (generated, vendored, new file): 517 lines

@@ -0,0 +1,517 @@ (new file; every line below is added)
# @generated DO NOT EDIT MANUALLY
# Template is at: .github/templates/linux_ci_workflow.yml.j2
# Generation script: .github/scripts/generate_ci_workflows.py
name: linux-xenial-py3.6-gcc7

on:
  pull_request:
    types: [opened, synchronize, reopened, unassigned]
  push:
    branches:
      - master
      - release/*
  workflow_dispatch:

env:
  BUILD_ENVIRONMENT: linux-xenial-py3.6-gcc7
  DOCKER_IMAGE_BASE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7
  SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
  XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla
  TORCH_CUDA_ARCH_LIST: 5.2
  IN_CI: 1
  IS_GHA: 1
  # This is used for the phase of adding wheel tests only, will be removed once completed
  IN_WHEEL_TEST: 1
  # Used for custom_opertor, jit_hooks, custom_backend, see .jenkins/pytorch/build.sh
  CUSTOM_TEST_ARTIFACT_BUILD_DIR: build/custom_test_artifacts
  ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
  PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }}
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  AWS_DEFAULT_REGION: us-east-1
  CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
  CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
concurrency:
  group: linux-xenial-py3.6-gcc7-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:

  ciflow_should_run:
    runs-on: ubuntu-18.04
    env:
      IS_PROBOT_TRIGGER_EVENT: ${{ (github.event.action == 'unassigned') && (github.event.assigneed.login == 'pytorchbot') }}
      LABEL_CONDITIONS: ${{ contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux') }}
      LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }}
    if: ${{ (github.repository_owner == 'pytorch') && (
            (github.event_name == 'push') ||
            (github.event_name == 'schedule') ||
            (contains(github.event.pull_request.labels.*.name, 'ciflow/all') || contains(github.event.pull_request.labels.*.name, 'ciflow/cpu') || contains(github.event.pull_request.labels.*.name, 'ciflow/default') || contains(github.event.pull_request.labels.*.name, 'ciflow/linux')) ||
            ((github.event_name == 'pull_request' && github.event.action != 'unassigned') && !contains(join(github.event.pull_request.labels.*.name), 'ciflow/')))
         }}
    steps:
      - name: noop
        run: echo running ciflow_should_run
      - name: print labels
        run: echo "${LABELS}"

  build:
    runs-on: linux.2xlarge
    needs: [ciflow_should_run]
    env:
      JOB_BASE_NAME: linux-xenial-py3.6-gcc7-build
    outputs:
      docker_image: ${{ steps.calculate-tag.outputs.docker_image }}
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
      - name: Log in to ECR
        env:
          AWS_RETRY_MODE: standard
          AWS_MAX_ATTEMPTS: 5
        run: |
          aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh
          bash /tmp/ecr-login.sh
          rm /tmp/ecr-login.sh
      - name: Chown workspace
        env:
          ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
        run: |
          retry () {
            "$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
          }
          retry docker pull "${ALPINE_IMAGE}"
          # Ensure the working directory gets chowned back to the current user
          docker run --pull=never --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
      - name: Clean workspace
        run: |
          rm -rf "${GITHUB_WORKSPACE:?}/*"
          rm -f ~/.ssh/authorized_keys
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: seemethere/add-github-ssh-key@v1
        with:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Preserve github env variables for use in docker
        run: |
          env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}"
      - name: Checkout PyTorch
        uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9
        with:
          # deep clone, to allow use of git merge-base
          fetch-depth: 0
          submodules: recursive
      - name: Calculate docker image tag
        id: calculate-tag
        run: |
          DOCKER_TAG=$(git rev-parse HEAD:.circleci/docker)
          echo "DOCKER_TAG=${DOCKER_TAG}" >> "${GITHUB_ENV}"
          echo "DOCKER_IMAGE=${DOCKER_IMAGE_BASE}:${DOCKER_TAG}" >> "${GITHUB_ENV}"
          echo "::set-output name=docker_tag::${DOCKER_TAG}"
          echo "::set-output name=docker_image::${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"
      - name: Check if image should be built
        id: check
        env:
          BASE_REVISION: ${{ github.event.pull_request.base.sha || github.sha }}
        run: |
          set -x
          # Check if image already exists, if it does then skip building it
          if docker manifest inspect "${DOCKER_IMAGE_BASE}:${DOCKER_TAG}"; then
            exit 0
          fi
          if [[ "$BASE_REVISION" = "$(git rev-parse HEAD)" ]]; then
            # if we're on the base branch then use the parent commit
            MERGE_BASE=$(git rev-parse HEAD~)
          else
            # otherwise we're on a PR, so use the most recent base commit
            MERGE_BASE=$(git merge-base HEAD "$BASE_REVISION")
          fi
          # Covers the case where a previous tag doesn't exist for the tree
          # this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly
          if ! git rev-parse "$MERGE_BASE:.circleci/docker"; then
            echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit"
            exit 1
          fi
          PREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker")
          # If no image exists but the hash is the same as the previous hash then we should error out here
          if [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then
            echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch"
            echo "       contact the PyTorch team to restore the original images"
            exit 1
          fi
          echo ::set-output name=rebuild::yes
      - name: Build and push docker image
        if: ${{ steps.check.outputs.rebuild }}
        env:
          DOCKER_SKIP_S3_UPLOAD: 1
        working-directory: .circleci/docker
        run: |
          export IMAGE_NAME=${DOCKER_IMAGE_BASE#308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/}
          ./build_docker.sh
      - name: Pull Docker image
        run: |
          retry () {
            "$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
          }
          retry docker pull "${DOCKER_IMAGE}"
      - name: Parse ref
        id: parse-ref
        run: .github/scripts/parse_ref.py
      - name: Build
        env:
          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
        run: |
          # detached container should get cleaned up by teardown_ec2_linux
          container_name=$(docker run \
            -e BUILD_ENVIRONMENT \
            -e JOB_BASE_NAME \
            -e MAX_JOBS="$(nproc --ignore=2)" \
            -e AWS_DEFAULT_REGION \
            -e IS_GHA \
            -e CIRCLE_PR_NUMBER \
            -e CIRCLE_SHA1 \
            -e CIRCLE_BRANCH \
            -e GITHUB_RUN_ID \
            -e SCCACHE_BUCKET \
            -e XLA_CLANG_CACHE_S3_BUCKET_NAME \
            -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
            -e SKIP_SCCACHE_INITIALIZATION=1 \
            -e TORCH_CUDA_ARCH_LIST \
            -e PR_LABELS \
            -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,github.com,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \
            --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
            --security-opt seccomp=unconfined \
            --cap-add=SYS_PTRACE \
            --tty \
            --detach \
            --user jenkins \
            -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
            -w /var/lib/jenkins/workspace \
            "${DOCKER_IMAGE}"
          )
          docker exec -t "${container_name}" sh -c 'sudo chown -R jenkins . && .jenkins/pytorch/build.sh'
      - name: Display and upload binary build size statistics (Click Me)
        # temporary hack: set CIRCLE_* vars, until we update
        # tools/stats/print_test_stats.py to natively support GitHub Actions
        env:
          SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }}
          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
          CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
          CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
        run: |
          COMMIT_TIME=$(git log --max-count=1 --format=%ct || echo 0)
          export COMMIT_TIME
          pip3 install requests==2.26 boto3==1.16.34
          python3 -m tools.stats.upload_binary_size_to_scuba || exit 0
      - name: Chown workspace
        run: |
          # Ensure the working directory gets chowned back to the current user
          docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
      - name: Archive artifacts into zip
        run: |
          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
      - uses: seemethere/upload-artifact-s3@v3
        name: Store PyTorch Build Artifacts on S3
        with:
          name: ${{ env.BUILD_ENVIRONMENT }}
          retention-days: 14
          if-no-files-found: error
          path:
            artifacts.zip
      - name: Hold runner for 2 hours or until ssh sessions have drained
        # Always hold for active ssh sessions
        if: always()
        run: .github/scripts/wait_for_ssh_to_drain.sh
      - name: Chown workspace
        if: always()
        env:
          ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
        run: |
          # Ensure the working directory gets chowned back to the current user
          docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
      - name: Kill containers, clean up images
        if: always()
        run: |
          # ignore expansion of "docker ps -q" since it could be empty
          # shellcheck disable=SC2046
          docker stop $(docker ps -q) || true
          # Prune all of the docker images
          docker system prune -af
      - name: Hold runner for 2 hours or until ssh sessions have drained
        # Always hold for active ssh sessions
        if: always()
        run: .github/scripts/wait_for_ssh_to_drain.sh
      - name: Clean up docker images
        if: always()
        run: |
          # Prune all of the docker images
          docker system prune -af

  generate-test-matrix:
    runs-on: ubuntu-18.04
    needs: [ciflow_should_run]
    env:
      TEST_RUNNER_TYPE: linux.2xlarge
      ENABLE_DISTRIBUTED_TEST: 1
      ENABLE_JIT_LEGACY_TEST: ''
      ENABLE_MULTIGPU_TEST: ''
      ENABLE_NOGPU_NO_AVX_TEST: ''
      ENABLE_NOGPU_NO_AVX2_TEST: ''
      ENABLE_SLOW_TEST: ''
      ENABLE_DOCS_TEST: ''
      ENABLE_BACKWARDS_COMPAT_TEST: ''
      ENABLE_XLA_TEST: ''
      ENABLE_NOARCH_TEST: ''
      NUM_TEST_SHARDS: 2
      MULTIGPU_RUNNER_TYPE: linux.16xlarge.nvidia.gpu
      NOGPU_RUNNER_TYPE: linux.2xlarge
      PR_BODY: ${{ github.event.pull_request.body }}
    outputs:
      matrix: ${{ steps.set-matrix.outputs.matrix }}
      render-matrix: ${{ steps.set-matrix.outputs.render-matrix }}
      ignore-disabled-issues: ${{ steps.set-matrix.outputs.ignore-disabled-issues }}
    container:
      image: python:3.9
    steps:
      - name: Install dependencies
        run: pip install typing-extensions==3.10
      - name: Clone pytorch/pytorch
        uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9
      - name: Generating test matrix
        id: set-matrix
        run: .github/scripts/generate_pytorch_test_matrix.py

  test:
    needs: [build, generate-test-matrix, ciflow_should_run]
    strategy:
      matrix: ${{ fromJson(needs.generate-test-matrix.outputs.matrix) }}
      fail-fast: false
    runs-on: ${{ matrix.runner }}
    env:
      DOCKER_IMAGE: ${{ needs.build.outputs.docker_image }}
      JOB_BASE_NAME: linux-xenial-py3.6-gcc7-test
      TEST_CONFIG: ${{ matrix.config }}
      SHARD_NUMBER: ${{ matrix.shard }}
      NUM_TEST_SHARDS: ${{ matrix.num_shards }}
      PYTORCH_IGNORE_DISABLED_ISSUES: ${{ needs.generate-test-matrix.outputs.ignore-disabled-issues }}
    steps:
      - name: Display EC2 information
        shell: bash
        run: |
          set -euo pipefail
          function get_ec2_metadata() {
            # Pulled from instance metadata endpoint for EC2
            # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
            category=$1
            curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
          }
          echo "ami-id: $(get_ec2_metadata ami-id)"
          echo "instance-id: $(get_ec2_metadata instance-id)"
          echo "instance-type: $(get_ec2_metadata instance-type)"
      - name: Log in to ECR
        env:
          AWS_RETRY_MODE: standard
          AWS_MAX_ATTEMPTS: 5
        run: |
          aws ecr get-login --no-include-email --region us-east-1 > /tmp/ecr-login.sh
          bash /tmp/ecr-login.sh
          rm /tmp/ecr-login.sh
      - name: Chown workspace
        env:
          ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
        run: |
          retry () {
            "$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
          }
          retry docker pull "${ALPINE_IMAGE}"
          # Ensure the working directory gets chowned back to the current user
          docker run --pull=never --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
      - name: Clean workspace
        run: |
          rm -rf "${GITHUB_WORKSPACE:?}/*"
          rm -f ~/.ssh/authorized_keys
      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
        uses: seemethere/add-github-ssh-key@v1
        with:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
      - name: Preserve github env variables for use in docker
        run: |
          env | grep '^GITHUB' > "/tmp/github_env_${GITHUB_RUN_ID}"
      - name: Checkout PyTorch
        uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9
        with:
          # deep clone, to allow use of git merge-base
          fetch-depth: 0
          submodules: recursive
      - name: Pull Docker image
        run: |
          retry () {
            "$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
          }
          retry docker pull "${DOCKER_IMAGE}"
      - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG
        if: ${{ contains(env.BUILD_ENVIRONMENT, 'cuda') && !contains(matrix.config, 'nogpu') }}
        run: |
          bash .github/scripts/install_nvidia_utils_linux.sh
          echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}"
      - name: Determine shm-size
        run: |
          shm_size="1g"
          case "${BUILD_ENVIRONMENT}" in
            *cuda*)
              shm_size="2g"
              ;;
            *rocm*)
              shm_size="8g"
              ;;
          esac
          echo "SHM_SIZE=${shm_size}" >> "${GITHUB_ENV}"
      - uses: seemethere/download-artifact-s3@0504774707cbc8603d7dca922e8026eb8bf3b47b
        name: Download PyTorch Build Artifacts
        with:
          name: ${{ env.BUILD_ENVIRONMENT }}
      - name: Unzip artifacts
        run: |
          unzip -o artifacts.zip
      - name: Output disk space left
        run: |
          sudo df -H
      - name: Parse ref
        id: parse-ref
        run: .github/scripts/parse_ref.py
      - name: Test
        env:
          PR_NUMBER: ${{ github.event.pull_request.number }}
          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
        # Time out the test phase after 240 minutes
        timeout-minutes: 240
        run: |
          set -x

          if [[ $TEST_CONFIG == 'multigpu' ]]; then
            TEST_COMMAND=.jenkins/pytorch/multigpu-test.sh
          elif [[ $BUILD_ENVIRONMENT == *onnx* ]]; then
            TEST_COMMAND=.jenkins/caffe2/test.sh
          else
            TEST_COMMAND=.jenkins/pytorch/test.sh
          fi
          # detached container should get cleaned up by teardown_ec2_linux
          # TODO: Stop building test binaries as part of the build phase
          # Used for GPU_FLAG since that doesn't play nice
          # shellcheck disable=SC2086
          container_name=$(docker run \
            ${GPU_FLAG:-} \
            -e BUILD_ENVIRONMENT \
            -e PR_NUMBER \
            -e CUSTOM_TEST_ARTIFACT_BUILD_DIR \
            -e GITHUB_ACTIONS \
            -e IN_CI \
            -e IS_GHA \
            -e CIRCLE_BRANCH \
            -e CIRCLE_SHA1 \
            -e CIRCLE_PR_NUMBER \
            -e AWS_DEFAULT_REGION \
            -e IN_WHEEL_TEST \
            -e SHARD_NUMBER \
            -e JOB_BASE_NAME \
            -e TEST_CONFIG \
            -e NUM_TEST_SHARDS \
            -e PYTORCH_IGNORE_DISABLED_ISSUES \
            -e PR_LABELS \
            -e MAX_JOBS="$(nproc --ignore=2)" \
            -e SCCACHE_BUCKET \
            -e http_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e https_proxy="http://internal-tf-lb-20210727220640487900000002-835786077.us-east-1.elb.amazonaws.com:3128" -e no_proxy="localhost,127.0.0.1,github.com,amazonaws.com,s3.amazonaws.com,169.254.169.254,169.254.170.2,/var/run/docker.sock" \
            -e XLA_CLANG_CACHE_S3_BUCKET_NAME \
            --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
            --ulimit stack=10485760:83886080 \
            --security-opt seccomp=unconfined \
            --cap-add=SYS_PTRACE \
            --shm-size="${SHM_SIZE}" \
            --tty \
            --detach \
            --name="${container_name}" \
            --user jenkins \
            -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
            -w /var/lib/jenkins/workspace \
            "${DOCKER_IMAGE}"
          )
          docker exec -t "${container_name}" sh -c "sudo chown -R jenkins . && pip install dist/*.whl && ${TEST_COMMAND}"
      - name: Chown workspace
        if: always()
        run: |
          # Ensure the working directory gets chowned back to the current user
          docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
      - name: Install render_test_results dependencies
        if: always()
        shell: bash
        run: |
          python3 -m pip install junitparser==2.1.1 rich==10.9.0
      - name: "[[ Click me for rendered test results (useful for finding failing tests) ]]"
        if: always()
        shell: bash
        # Encoding is weird on windows, just try to default to utf-8 if possible
        env:
          PYTHONIOENCODING: "utf-8"
        run: |
          python3 tools/render_junit.py test/
      - name: Zip test reports for upload
        if: always()
        env:
          FILE_SUFFIX: '${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}'
        run: |
          # Remove any previous test reports if they exist
          rm -f test-reports-*.zip
          zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml'
      - uses: seemethere/upload-artifact-s3@v3
        name: Store Test Reports on S3
        if: always()
        with:
          retention-days: 14
          if-no-files-found: error
          path:
            test-reports-*.zip
      - name: Display and upload test statistics (Click Me)
        if: always()
        # temporary hack: set CIRCLE_* vars, until we update
        # tools/stats/print_test_stats.py to natively support GitHub Actions
        env:
          AWS_DEFAULT_REGION: us-east-1
          CIRCLE_BRANCH: ${{ steps.parse-ref.outputs.branch }}
          JOB_BASE_NAME: linux-xenial-py3.6-gcc7-test
          CIRCLE_PR_NUMBER: ${{ github.event.pull_request.number }}
          CIRCLE_SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
          CIRCLE_TAG: ${{ steps.parse-ref.outputs.tag }}
          CIRCLE_WORKFLOW_ID: '${{ github.run_id }}_${{ github.run_number }}'
        shell: bash
        run: |
          python3 -m pip install -r requirements.txt
          python3 -m pip install boto3==1.16.34
          python3 -m tools.stats.print_test_stats --upload-to-s3 --compare-with-s3 test
      - name: Hold runner for 2 hours or until ssh sessions have drained
        # Always hold for active ssh sessions
        if: always()
        run: .github/scripts/wait_for_ssh_to_drain.sh
      - name: Chown workspace
        if: always()
        env:
          ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine"
        run: |
          # Ensure the working directory gets chowned back to the current user
          docker run --rm -v "$(pwd)":/v -w /v "${ALPINE_IMAGE}" chown -R "$(id -u):$(id -g)" .
      - name: Kill containers, clean up images
        if: always()
        run: |
          # ignore expansion of "docker ps -q" since it could be empty
          # shellcheck disable=SC2046
          docker stop $(docker ps -q) || true
          # Prune all of the docker images
          docker system prune -af
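A detail worth noting in the workflow above: the Docker image tag is the git tree hash of the .circleci/docker directory (git rev-parse HEAD:.circleci/docker), so the image is rebuilt exactly when the Docker build context changes. The same content-addressed lookup can be reproduced locally; a small sketch, with an illustrative helper name:

import subprocess

def docker_tag(rev="HEAD", path=".circleci/docker"):
    # The tree hash changes iff any file under `path` changes, which is what
    # makes it usable as a content-addressed image tag.
    return subprocess.check_output(
        ["git", "rev-parse", f"{rev}:{path}"], text=True).strip()

# Mirrors ${DOCKER_IMAGE_BASE}:${DOCKER_TAG} from the workflow above.
print("308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/"
      "pytorch-linux-xenial-py3.6-gcc7:" + docker_tag())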
@@ -327,8 +327,10 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
   set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
   set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
   set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
-  if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" AND CMAKE_SYSTEM_NAME STREQUAL "Darwin")
-    set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
+  if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
+    if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
+      set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
+    endif()
   endif()
   if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND
       CMAKE_C_COMPILER_VERSION VERSION_GREATER 6.9 AND CMAKE_C_COMPILER_VERSION VERSION_LESS 8)
@@ -544,10 +544,11 @@ static void check_shape_forward(const at::Tensor& input,
   bool kernel_size_correct = true;
 
   TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups),
-      "Given groups=", groups, ", weight of size ", weight_sizes,
-      ", expected input", input.sizes(), " to have ",
-      (weight_sizes[1] * groups), " channels, but got ", input.size(1),
-      " channels instead");
+              "Given groups=", groups, ", weight of size ", weight_sizes,
+              ", expected input", input.sizes(), " to have ",
+              (weight_sizes[1] * groups), " channels, but got ", input.size(1),
+              " channels instead");
 
   TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]),
       "Given weight of size ", weight_sizes,
       ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements",

@@ -850,7 +851,7 @@ at::Tensor _convolution(
 
   check_shape_forward(input, weight_sizes, bias, params);
 
-  if (input.size(0) == 0) {
+  if (input.size(0) == 0 || input.size(1) == 0) {
     // don't send empty inputs through backends
     // but need to compute correct output size first and set up history for params
     std::vector<int64_t> o;

@@ -862,6 +863,9 @@ at::Tensor _convolution(
         params.output_padding, params.stride, params.dilation,
         params.groups);
   }
+  if (input.size(1) == 0) {
+    o[input_channels_dim] = 0;
+  }
   if (input_is_mkldnn && weight.is_mkldnn()) {
     // mkldnn will error on the below 0-dim handling code
     return empty_mkldnn(

@@ -871,10 +875,12 @@ at::Tensor _convolution(
         input.options().device_opt(),
         input.options().pinned_memory_opt());
   }
 
   auto weight_view = at::_unsafe_view(weight, -1);
-  auto out = input*weight_view[0];
-  if (bias.defined())
+  auto out = (input.size(1) == 0) ? (input.view(-1) * weight_view) : (input * weight_view[0]);
+  if (bias.defined()) {
     out.add_(bias[0]);
+  }
   return out.view(o);
 }
 
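The net effect of the `input.size(1) == 0` branches is that convolutions now accept inputs that are empty in the channel dimension, not just the batch dimension, while channel-count mismatches still raise. A small illustration using the shapes from the new test_conv_empty_channel test further down; the printed shape is left unasserted since the empty fast path returns an empty result:

import torch

mod = torch.nn.Conv1d(in_channels=0, out_channels=8, kernel_size=2, stride=2)
inp = torch.randn(2, 0, 15, requires_grad=True)  # zero channels
out = mod(inp)             # previously raised; now returns an empty tensor
out.sum().backward()       # autograd history is set up even for the empty input
print(out.shape)

# A channel-count mismatch is still an error, as the test asserts:
try:
    mod(torch.randn(2, 1, 0))
except RuntimeError as e:
    print(e)  # "Given groups=1, weight of size ..., expected input ..."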
@@ -52,20 +52,17 @@ static inline void slow_conv2d_shape_check(
   }
 
   const int64_t ndim = input.dim();
-  const int64_t dim_batch = 0;
   const int64_t dim_planes = 1;
   const int64_t dim_height = 2;
   const int64_t dim_width = 3;
 
-  // Allow for empty batch size but not other dimensions
-  bool valid_empty = ndim == 4 && input.size(dim_batch) == 0 &&
-      input.size(dim_planes) != 0 && input.size(dim_height) != 0 &&
-      input.size(dim_width) != 0;
-
-  TORCH_CHECK(
-      (input.numel() > 0 || valid_empty) && ndim == 4,
-      "non-empty 4D input tensor expected but got: ",
-      input.sizes());
+  // Allow for empty batch size and channel size but not other dimensions
+  TORCH_CHECK(ndim == 4, "Expected 4D input tensor, but got: ", input.sizes());
+  for (int64_t dim = 2; dim < ndim; ++dim) {
+    TORCH_CHECK(input.size(dim) != 0,
+                "Expected non-zero size for input dimension ", dim,
+                ", but got input shape: ", input.sizes(), ". Only the batch and channel dimensions support size 0.");
+  }
 
   const int64_t input_height = input.size(dim_height);
   const int64_t input_width = input.size(dim_width);

@@ -109,7 +106,9 @@ static inline void slow_conv2d_shape_check(
     if (weight.dim() == 2) {
       n_input_plane /= (kernel_height * kernel_width);
     }
-    check_dim_size(input, ndim, dim_planes, n_input_plane);
+    if (input.size(1) != 0) {
+      check_dim_size(input, ndim, dim_planes, n_input_plane);
+    }
   }
 
   if (grad_output.defined()) {

@@ -529,6 +528,7 @@ std::tuple<Tensor, Tensor> slow_conv2d_forward_cpu(
       padding,
       output,
       finput);
+
   return std::make_tuple(output, finput);
 }
 

@@ -559,6 +559,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> slow_conv2d_backward_out_cpu(
     at::sum_out(grad_bias, grad_output, IntArrayRef{0, 2, 3});
   }
 
+
   if (grad_weight.defined()) {
     grad_weight.resize_(weight.sizes());
     grad_weight.zero_();
@@ -82,11 +82,11 @@ static void BM_deep_wide_static(benchmark::State& state) {
   auto user_emb = torch::randn({batch_size, 1, embedding_size});
   auto wide = torch::randn({batch_size, num_features});
 
-  std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
+  std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});
 
-  smod(inputs);
+  smod(inputs, {});
   for (auto _ : state) {
-    smod(inputs);
+    smod(inputs, {});
   }
 }
 

@@ -104,11 +104,11 @@ static void BM_deep_wide_static_threaded(benchmark::State& state) {
   auto user_emb = torch::randn({batch_size, 1, embedding_size});
   auto wide = torch::randn({batch_size, num_features});
 
-  std::vector<at::Tensor> inputs({ad_emb_packed, user_emb, wide});
+  std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});
 
-  sr(inputs);
+  sr(inputs, {});
   for (auto _ : state) {
-    sr(inputs);
+    sr(inputs, {});
   }
 }
 

@@ -118,11 +118,11 @@ static void BM_leaky_relu_const(benchmark::State& state) {
 
   const int batch_size = state.range(0);
   auto data = torch::randn({batch_size, num_features});
-  std::vector<at::Tensor> inputs({data});
+  std::vector<c10::IValue> inputs({data});
 
-  smod(inputs);
+  smod(inputs, {});
   for (auto _ : state) {
-    smod(inputs);
+    smod(inputs, {});
   }
 }
 

@@ -133,11 +133,11 @@ static void BM_leaky_relu(benchmark::State& state) {
   const int batch_size = state.range(0);
   auto neg_slope = torch::randn(1);
   auto data = torch::randn({batch_size, num_features});
-  std::vector<at::Tensor> inputs({data, neg_slope[0]});
+  std::vector<c10::IValue> inputs({data, neg_slope[0]});
 
-  smod(inputs);
+  smod(inputs, {});
   for (auto _ : state) {
-    smod(inputs);
+    smod(inputs, {});
   }
 }
 

@@ -150,11 +150,11 @@ static void BM_signed_log1p(benchmark::State& state) {
 
   const int num_elements = state.range(0);
   auto data = torch::randn({num_elements});
-  std::vector<at::Tensor> inputs({data});
+  std::vector<c10::IValue> inputs({data});
 
-  smod(inputs);
+  smod(inputs, {});
   for (auto _ : state) {
-    smod(inputs);
+    smod(inputs, {});
   }
 }
 

@@ -170,11 +170,11 @@ static void BM_long_static_memory_optimization(benchmark::State& state) {
   auto a = torch::randn({N, N});
   auto b = torch::randn({N, N});
   auto c = torch::randn({N, N});
-  std::vector<at::Tensor> inputs({a, b, c});
+  std::vector<c10::IValue> inputs({a, b, c});
 
-  smod(inputs);
+  smod(inputs, {});
   for (auto _ : state) {
-    smod(inputs);
+    smod(inputs, {});
   }
 }
 
@@ -755,9 +755,9 @@ TEST(StaticRuntime, LongModel) {
   at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
 
   // run static runtime
-  std::vector<at::Tensor> input_tensors({a, b, c});
+  std::vector<c10::IValue> input_tensors({a, b, c});
   torch::jit::StaticModule smod(mod);
-  at::Tensor output_2 = smod(input_tensors)[0];
+  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
   smod.runtime().check_for_memory_leak();
   EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
 }

@@ -773,9 +773,9 @@ TEST(StaticRuntime, TrivialModel) {
   at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
 
   // run static runtime
-  std::vector<at::Tensor> input_tensors({a, b, c});
+  std::vector<c10::IValue> input_tensors({a, b, c});
   torch::jit::StaticModule smod(mod);
-  at::Tensor output_2 = smod(input_tensors)[0];
+  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
   smod.runtime().check_for_memory_leak();
   EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
 }

@@ -789,9 +789,9 @@ TEST(StaticRuntime, LeakyReLU) {
   at::Tensor output_1 = mod.forward(input_ivalues).toTensor();
 
   // run static runtime
-  std::vector<at::Tensor> input_tensors({inputs});
+  std::vector<c10::IValue> input_tensors({inputs});
   torch::jit::StaticModule smod(mod);
-  at::Tensor output_2 = smod(input_tensors)[0];
+  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
   smod.runtime().check_for_memory_leak();
   EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
 }

@@ -813,8 +813,10 @@ TEST(StaticRuntime, DeepWide) {
   auto output_1 = getTensor(mod.forward(inputs));
 
   // run static runtime
-  std::vector<at::Tensor> input_tensors({ad_emb_packed, user_emb, wide});
-  at::Tensor output_2 = smod(input_tensors)[0];
+  std::vector<c10::IValue> input_tensors({ad_emb_packed, user_emb, wide});
+  auto outputs = smod(input_tensors, {}).toTuple()->elements();
+  ASSERT_TRUE(outputs.size() > 0);
+  at::Tensor output_2 = outputs[0].toTensor();
   smod.runtime().check_for_memory_leak();
   EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
 }

@@ -947,9 +949,11 @@ TEST(StaticRuntime, CleanUpMemory) {
   auto output_1 = getTensor(mod.forward(inputs));
 
   // run static runtime
-  std::vector<at::Tensor> input_tensors(
+  std::vector<c10::IValue> input_tensors(
       {ad_emb_packed, user_emb, wide});
-  at::Tensor output_2 = runtime(input_tensors)[0];
+  auto outputs = runtime(input_tensors, {}).toTuple()->elements();
+  ASSERT_TRUE(outputs.size() > 0);
+  auto output_2 = outputs[0].toTensor();
   runtime.check_for_memory_leak();
   EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
   if (manage_output_tensors) {

@@ -1053,9 +1057,9 @@ TEST(StaticRuntime, ManageOutputTensorsWithDeallocateOutputTensors) {
       torch::randn({batch_size, 1, embedding_size});
   auto user_emb = torch::randn({batch_size, 1, embedding_size});
   auto wide = torch::randn({batch_size, num_features});
-  std::vector<at::Tensor> input_tensors(
+  std::vector<c10::IValue> input_tensors(
       {ad_emb_packed, user_emb, wide});
-  runtime(input_tensors)[0];
+  runtime(input_tensors, {});
   runtime.check_for_memory_leak();
   runtime.deallocateOutputTensors();
   runtime.checkOutputTensorMemoryLeaks();

@@ -1079,21 +1083,21 @@ TEST(StaticRuntime, ManageOutputTensorsWithoutDeallocateOutputTensors) {
       torch::randn({batch_size, 1, embedding_size});
   auto user_emb = torch::randn({batch_size, 1, embedding_size});
   auto wide = torch::randn({batch_size, num_features});
-  std::vector<at::Tensor> input_tensors(
+  std::vector<c10::IValue> input_tensors(
      {ad_emb_packed, user_emb, wide});
   // Profile run.
-  runtime(input_tensors)[0];
+  runtime(input_tensors, {});
   runtime.deallocateOutputTensors();
   // Run again to allocate output Tensors without deallocating them.
-  runtime(input_tensors)[0];
+  runtime(input_tensors, {});
   // Memory leak checking fails.
   EXPECT_THROW(runtime.checkOutputTensorMemoryLeaks(), std::exception);
   // Calling the runtime without deallocation fails too.
-  EXPECT_THROW(runtime(input_tensors)[0], std::exception);
+  EXPECT_THROW(runtime(input_tensors, {}), std::exception);
   // After deallocation, everything works fine.
   runtime.deallocateOutputTensors();
   runtime.checkOutputTensorMemoryLeaks();
-  runtime(input_tensors)[0];
+  runtime(input_tensors, {});
 }
 
 TEST(StaticRuntime, FusionPass) {
@@ -46,6 +46,7 @@
 #include <torch/csrc/jit/runtime/custom_operator.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
+#include <torch/csrc/jit/runtime/jit_trace.h>
 #include <torch/csrc/jit/runtime/profiling_record.h>
 #include <torch/csrc/jit/runtime/symbolic_script.h>
 #include <torch/csrc/jit/serialization/import.h>

@@ -58,6 +59,8 @@
 #include <c10/util/Exception.h>
 #include <c10/util/ThreadLocalDebugInfo.h>
 
+#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
 #include <algorithm>
 #include <cstddef>
 #include <functional>

@@ -67,6 +70,7 @@
 #include <stdexcept>
 #include <string>
 #include <tuple>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
 #include <vector>

@@ -1803,6 +1807,46 @@ TEST(LoopPeelerTest, SimpleNestedLoops2) {
   }
 }
 
+TEST(JitTracing, Basic) {
+  constexpr int batch_size = 4;
+  constexpr int input_size = 256;
+
+  int hidden_size = 2 * input_size;
+
+  auto input = at::randn({batch_size, input_size}, at::kCPU);
+  auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
+  auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
+  auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
+  auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
+
+  auto graph = build_lstm();
+  auto stack = createStack({input, hx, cx, w_ih, w_hh});
+  auto traced = TraceGraph(graph, stack);
+  Tensor prof_out;
+  pop(stack, prof_out);
+
+  {
+    stack = createStack({input, hx, cx, w_ih, w_hh});
+    Code cd(traced, "traced");
+    InterpreterState is{cd};
+    is.run(stack);
+    Tensor traced_out;
+    pop(stack, traced_out);
+    torch::allclose(prof_out, traced_out);
+  }
+
+  {
+    stack = createStack({input, hx, cx, w_ih, w_hh});
+    Code cd(graph, "graph");
+    InterpreterState is{cd};
+    is.run(stack);
+    Tensor scripted_out;
+    pop(stack, scripted_out);
+    torch::allclose(prof_out, scripted_out);
+  }
+}
+
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
   static const auto basic_example = R"JIT(
     def basic(x, y):
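The new JitTracing test runs an LSTM graph through TraceGraph and checks both the traced and the original graph against the profiled output. A rough Python-level analogue of the same consistency check using the public tracing API; the module choice here is illustrative:

import torch

lstm = torch.nn.LSTMCell(256, 512)
x = torch.randn(4, 256)
hx, cx = torch.randn(4, 512), torch.randn(4, 512)

traced = torch.jit.trace(lstm, (x, (hx, cx)))
h_ref, c_ref = lstm(x, (hx, cx))
h_tr, c_tr = traced(x, (hx, cx))
# Traced execution must reproduce eager execution on the traced shapes.
assert torch.allclose(h_ref, h_tr) and torch.allclose(c_ref, c_tr)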
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import io
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import gc
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import shutil
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
|
@ -1289,6 +1291,101 @@ class TestFreezing(JitTestCase):
|
|||
FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
|
||||
FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)
|
||||
|
||||
def test_freeze_module_with_user_preserved_attribute_on_submodule(self):
|
||||
class SubModule(nn.Module):
|
||||
def __init__(self):
|
||||
super(SubModule, self).__init__()
|
||||
self.a = 1
|
||||
self.b = 2
|
||||
|
||||
def forward(self):
|
||||
return self.a + self.b
|
||||
|
||||
class Module(nn.Module):
|
||||
def __init__(self):
|
||||
super(Module, self).__init__()
|
||||
self.sub1 = SubModule()
|
||||
self.sub2 = SubModule()
|
||||
|
||||
def forward(self):
|
||||
return self.sub1() + self.sub2()
|
||||
|
||||
m = torch.jit.script(Module())
|
||||
m.eval()
|
||||
m = torch.jit.freeze(m, preserved_attrs=['sub1.a', 'sub2.a'])
|
||||
fm = m._c
|
||||
|
||||
self.assertTrue(fm.hasattr('sub1'))
|
||||
self.assertTrue(fm.sub1.hasattr('a'))
|
||||
self.assertFalse(fm.sub1.hasattr('b'))
|
||||
self.assertTrue(fm.hasattr('sub2'))
|
||||
self.assertTrue(fm.sub2.hasattr('a'))
|
||||
self.assertFalse(fm.sub2.hasattr('b'))
|
||||
self.assertEqual(m(), 6)
|
||||
m.sub1.a += 1
|
||||
self.assertEqual(m(), 7)
|
||||
|
||||
def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self):
|
||||
class SubModule(nn.Module):
|
||||
def __init__(self):
|
||||
super(SubModule, self).__init__()
|
||||
self.a = 1
|
||||
self.b = 2
|
||||
|
||||
def forward(self):
|
||||
return self.a + self.b
|
||||
|
||||
@torch.jit.export
|
||||
def method_a(self):
|
||||
return 42
|
||||
|
||||
class Module(nn.Module):
|
||||
def __init__(self):
|
||||
super(Module, self).__init__()
|
||||
self.sub = SubModule()
|
||||
|
||||
def forward(self):
|
||||
return 1
|
||||
|
||||
m = torch.jit.script(Module())
|
||||
m.eval()
|
||||
fm = torch.jit.freeze(m, preserved_attrs=['sub.a', 'sub.method_a'])._c
|
||||
|
||||
self.assertTrue(fm.hasattr('sub'))
|
||||
self.assertTrue(fm.sub.hasattr('a'))
|
||||
self.assertFalse(fm.sub.hasattr('b'))
|
||||
self.assertTrue(fm.sub._has_method('method_a'))
|
||||
|
||||
def test_freeze_module_with_user_preserved_method_on_submodule(self):
|
||||
class SubModule(nn.Module):
|
||||
def __init__(self):
|
||||
super(SubModule, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return self.method_a(x) + self.method_b(x)
|
||||
|
||||
def method_a(self, x):
|
||||
return x * x
|
||||
|
||||
def method_b(self, x):
|
||||
return x + x
|
||||
|
||||
class Module(nn.Module):
|
||||
def __init__(self):
|
||||
super(Module, self).__init__()
|
||||
self.sub = SubModule()
|
||||
|
||||
def forward(self, x):
|
||||
return self.sub(x)
|
||||
|
||||
m = torch.jit.script(Module())
|
||||
m.eval()
|
||||
fm = torch.jit.freeze(m, preserved_attrs=['sub.method_a'])._c
|
||||
|
||||
self.assertTrue(fm.hasattr('sub'))
|
||||
self.assertTrue(fm.sub._has_method('method_a'))
|
||||
self.assertFalse(fm.sub._has_method('method_b'))
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_module_with_shared_type_instances(self):
|
||||
class Child(nn.Module):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
import torch
|
||||
import torch._C
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import inspect
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from torch.testing._internal.jit_utils import JitTestCase, make_global
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
# flake8: noqa
|
||||
# TODO: enable linting check for this file
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["module: onnx"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: mobile"]
|
||||
|
||||
import torch
|
||||
import torch._C
|
||||
import torch.backends.xnnpack
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA, _inline_everything
|
||||
from torch import nn
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from itertools import product as product
|
||||
from typing import NamedTuple, Optional
|
||||
import io
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase, execWrapper
|
||||
import operator
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import unittest
|
||||
import io
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
import io
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
# Owner(s): ["oncall: jit"]
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
|
|
|||
|
|
@ -1481,6 +1481,7 @@ class TestNormalizeOperators(JitTestCase):
|
|||
"index_put",
|
||||
"nn.functional.conv2d",
|
||||
"nn.functional.dropout",
|
||||
"nn.functional.embedding", # Implemented with a lambda
|
||||
"polygamma",
|
||||
"special.polygamma",
|
||||
"repeat",
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]

import torch

# This is how we include tests located in test/jit/...

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

import unittest
import os
import random

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

import sys
import os
import contextlib

@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: jit"]

import unittest
import os
import sys

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

import sys
sys.argv.append("--jit_executor=legacy")
from test_jit_fuser import *  # noqa: F403

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

import sys
sys.argv.append("--jit_executor=legacy")
from test_jit import *  # noqa: F403

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

import sys
sys.argv.append("--jit_executor=profiling")
from test_jit import *  # noqa: F403

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

import sys
sys.argv.append("--jit_executor=simple")
from test_jit import *  # noqa: F403

@@ -1,3 +1,5 @@
# Owner(s): ["oncall: jit"]

from test_jit import JitTestCase
from torch.testing._internal.common_utils import run_tests

@@ -13982,6 +13982,32 @@ class TestNNDeviceType(NNTestCase):
        self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
        self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))

    def test_conv_empty_channel(self, device):
        in_channels = 0
        mod = torch.nn.Conv1d(in_channels, 8, 2, stride=2).to(device)
        inp = torch.randn(2, 0, 15, device=device)
        self._test_module_empty_input(mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 0, device=device)
            mod(inp)

        mod = torch.nn.Conv2d(in_channels, 33, 3, stride=2).to(device)
        inp = torch.randn(2, 0, 50, 100, device=device)
        self._test_module_empty_input(mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 40, 0, device=device)
            mod(inp)

        mod = torch.nn.Conv3d(in_channels, 33, 3, stride=2).to(device)
        inp = torch.randn(2, 0, 50, 20, 40, device=device)
        self._test_module_empty_input(mod, inp, check_size=False)

        with self.assertRaisesRegex(RuntimeError, "Given groups=1, weight"):
            inp = torch.randn(2, 1, 50, 0, 40, device=device)
            mod(inp)

    def test_group_conv_empty(self, device):
        mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device)
        inp = torch.randn(0, 4, 4, 4, device=device)

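A quick standalone illustration of the behavior this hunk pins down (shapes chosen arbitrarily, not taken from the test suite):

```python
# Sketch: a Conv1d built with in_channels=0 accepts zero-channel inputs,
# but a channel-count mismatch still raises the usual
# "Given groups=1, weight..." RuntimeError.
import torch

mod = torch.nn.Conv1d(in_channels=0, out_channels=8, kernel_size=2, stride=2)
out = mod(torch.randn(2, 0, 15))
print(out.shape)  # torch.Size([2, 8, 7])

try:
    mod(torch.randn(2, 1, 15))  # 1 channel given, 0 expected
except RuntimeError as err:
    print("rejected:", err)
```
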
@@ -247,8 +247,8 @@ class TestOptim(TestCase):

    def _test_complex_optimizer(self, optimizer_constructor):
        complex_param = torch.randn(5, 5, dtype=torch.complex64, requires_grad=True)
        complex_opt = optimizer_constructor(complex_param)
        real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()
        complex_opt = optimizer_constructor(complex_param)
        real_opt = optimizer_constructor(real_param)

        for i in range(3):

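For context, the equivalence this helper checks can be sketched end to end with one concrete optimizer. Everything below (Adadelta, the loss, the iteration count) is an illustrative choice, not part of the test:

```python
import torch

# A complex parameter and its view as pairs of real numbers should follow
# the same trajectory when the optimizer treats a complex value as two reals.
complex_param = torch.randn(3, dtype=torch.complex64, requires_grad=True)
real_param = torch.view_as_real(complex_param).detach().clone().requires_grad_()

complex_opt = torch.optim.Adadelta([complex_param])
real_opt = torch.optim.Adadelta([real_param])

for _ in range(3):
    complex_opt.zero_grad()
    real_opt.zero_grad()
    complex_param.abs().sum().backward()                    # |z|
    real_param.pow(2).sum(dim=-1).sqrt().sum().backward()   # same loss on (re, im)
    complex_opt.step()
    real_opt.step()

torch.testing.assert_close(torch.view_as_real(complex_param), real_param)
```
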
@@ -652,11 +652,6 @@ class TestOptim(TestCase):
                [param], lr=1e-1, initial_accumulator_value=0.1
            )
        )
        self._test_complex_optimizer(
            lambda param: optimizer(
                [param], lr=1e-1, initial_accumulator_value=0.1, weight_decay=1
            )
        )

    def test_adamax(self):
        for optimizer in [optim.Adamax, optim_mt.Adamax]:

@@ -17,7 +17,7 @@ class StaticModule:

    def __call__(self, *args, **kwargs):
        if not kwargs:
            return self.static_module(args)
            return self.static_module(args, {})
        else:
            return self.static_module(args, kwargs)

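A tiny mimic (not the real static runtime) of the calling convention the wrapper above normalizes to:

```python
# The underlying runner is assumed to always take (args, kwargs); __call__
# fills in an empty dict when no kwargs were supplied.
class FakeStaticRuntime:
    def __call__(self, args, kwargs):
        return ("ran", args, kwargs)

runner = FakeStaticRuntime()
print(runner((1, 2), {}))            # positional-only call
print(runner((1,), {"beta": 0.5}))   # kwargs forwarded as-is
```
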
@@ -227,20 +227,20 @@ class TestStaticModule(TestCase):
        bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
        top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
        ref_bot = bot_l(bot_inp)
        acc_bot = bot_l_acc(bot_inp)[0]
        acc_bot = bot_l_acc(bot_inp)
        torch.testing.assert_close(acc_bot, ref_bot)
        ref_top = top_l(top_inp)
        acc_top = top_l_acc(top_inp)[0]
        acc_top = top_l_acc(top_inp)
        torch.testing.assert_close(acc_top, ref_top)
        for _ in range(5):
            with torch.no_grad():
                bot_inp = torch.randn(2048, 512)  # torch.Size([2048, 512])
                top_inp = torch.randn(2048, 100)  # torch.Size([2048, 100])
                ref_bot = bot_l(bot_inp)
                acc_bot = bot_l_acc(bot_inp)[0]
                acc_bot = bot_l_acc(bot_inp)
                torch.testing.assert_close(acc_bot, ref_bot)
                ref_top = top_l(top_inp)
                acc_top = top_l_acc(top_inp)[0]
                acc_top = top_l_acc(top_inp)
                torch.testing.assert_close(acc_top, ref_top)

    def test_trivial_graph(self):

@@ -248,7 +248,7 @@ class TestStaticModule(TestCase):
        tg = torch.jit.script(trivial_graph)
        o_ref = tg(s, s, s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s, s, s)[0]
        o_test = tg_a(s, s, s)
        torch.testing.assert_close(o_ref, o_test)

    def test_leaky_relu(self):

@@ -256,7 +256,7 @@ class TestStaticModule(TestCase):
        tg = torch.jit.script(nn.LeakyReLU(0.1))
        o_ref = tg(s)
        tg_a = StaticModule(tg)
        o_test = tg_a(s)[0]
        o_test = tg_a(s)
        torch.testing.assert_close(o_ref, o_test)

    def test_attr(self):

@@ -292,7 +292,7 @@ class TestStaticModule(TestCase):

        ms = torch.jit.script(m)
        sm = StaticModule(ms)
        output_sm = sm(input)[0]
        output_sm = sm(input)
        torch.testing.assert_close(output_s, output_sm)
        sm.benchmark([input], {}, 2, 2)
        sm.benchmark_individual_ops([input], {}, 2, 2)

@@ -285,6 +285,7 @@ core_sources_full_mobile_no_backend_interface = [
    "torch/csrc/jit/runtime/script_profile.cpp",
    "torch/csrc/jit/runtime/symbolic_script.cpp",
    "torch/csrc/jit/runtime/symbolic_shape_registry.cpp",
    "torch/csrc/jit/runtime/jit_trace.cpp",
    "torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
    "torch/csrc/jit/serialization/import.cpp",
    "torch/csrc/jit/serialization/import_export_helpers.cpp",

@@ -639,27 +639,28 @@ def convert_fx(

    * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.

    * `qconfig_dict`: qconfig_dict with the either
      a) same keys as what is passed to the qconfig_dict in prepare_fx API, with same values or `None`.
      b) additional keys with values set to `None`
      For each entry whose value is set to None, we skip quantizing that entry in the model.
      Example:
    * `qconfig_dict`: qconfig_dict with either same keys as what is passed to
      the qconfig_dict in `prepare_fx` API, with same values or `None`, or
      additional keys with values set to `None`

      For each entry whose value is set to None, we skip quantizing that entry in the model::

        qconfig_dict = {
            # used for object_type, skip quantizing torch.nn.functional.add
            "object_type": [
                (torch.nn.functional.add, None),
                (torch.nn.functional.linear, qconfig_from_prepare)
                ...,
            ],

            # used for object_type, skip quantizing torch.nn.functional.add
            "object_type": [
                (torch.nn.functional.add, None),
                (torch.nn.functional.linear, qconfig_from_prepare)
                ...,
            ],

            # sed for module names, skip quantizing "foo.bar"
            "module_name": [
                ("foo.bar", None)
                ...,
            ],
            # sed for module names, skip quantizing "foo.bar"
            "module_name": [
                ("foo.bar", None)
                ...,
            ],
        }

    Return:
        A quantized model (GraphModule)

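A hedged end-to-end sketch of the workflow the docstring describes. The module, shapes, and backend are illustrative, and the convert-time `qconfig_dict` keyword is assumed to behave as documented above:

```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = M().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(model, qconfig_dict)
prepared(torch.randn(1, 4))  # calibration pass

# Per the docstring: an entry mapped to None at convert time is skipped.
convert_qconfig_dict = {
    "": get_default_qconfig("fbgemm"),
    "module_name": [("linear", None)],
}
quantized = convert_fx(prepared, qconfig_dict=convert_qconfig_dict)
```
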
@@ -11,11 +11,10 @@ namespace deploy {
static const size_t NUM_FROZEN_PY_BUILTIN_MODULES = 6;
static const size_t NUM_FROZEN_PY_STDLIB_MODULES = 680;

extern "C" struct _frozen _PyImport_FrozenModules[];
extern "C" struct _frozen _PyImport_FrozenModules_torch[];
extern "C" PyObject* initModule(void);

REGISTER_TORCH_DEPLOY_BUILTIN(cpython_internal, PyImport_FrozenModules);
REGISTER_TORCH_DEPLOY_BUILTIN(frozenpython, _PyImport_FrozenModules);
REGISTER_TORCH_DEPLOY_BUILTIN(
    frozentorch,
    _PyImport_FrozenModules_torch,

@@ -34,78 +34,6 @@ using namespace py::literals;
#define PYOBJ_ASSERT(obj) assert(NULL != obj);
#endif

#define FOREACH_LIBRARY(_) \
  _(array) \
  _(_asyncio) \
  _(audioop) \
  _(binascii) \
  _(_bisect) \
  _(_blake2) \
  _(_bz2) \
  _(cmath) \
  _(_codecs_cn) \
  _(_codecs_hk) \
  _(_codecs_iso2022) \
  _(_codecs_jp) \
  _(_codecs_kr) \
  _(_codecs_tw) \
  _(_contextvars) \
  _(_crypt) \
  _(_csv) \
  _(_ctypes) \
  _(_ctypes_test) \
  _(_curses) \
  _(_curses_panel) \
  _(_datetime) \
  _(_decimal) \
  _(_elementtree) \
  _(fcntl) \
  _(grp) \
  _(_hashlib) \
  _(_heapq) \
  _(_json) \
  _(_lsprof) \
  _(_lzma) \
  _(math) \
  _(_md5) \
  _(mmap) \
  _(_multibytecodec) \
  _(_multiprocessing) \
  _(nis) \
  _(_opcode) \
  _(ossaudiodev) \
  _(parser) \
  _(_pickle) \
  _(_posixsubprocess) \
  _(pyexpat) \
  _(_queue) \
  _(_random) \
  _(readline) \
  _(resource) \
  _(select) \
  _(_sha1) \
  _(_sha256) \
  _(_sha3) \
  _(_sha512) \
  _(_socket) \
  _(spwd) \
  _(_ssl) \
  _(_struct) \
  _(syslog) \
  _(termios) \
  _(_testbuffer) \
  _(_testcapi) \
  _(_testimportmultiple) \
  _(_testmultiphase) \
  _(unicodedata) \
  _(xxlimited) \
  _(_xxtestfuzz) \
  _(zlib)

#define DECLARE_LIBRARY_INIT(name) extern "C" PyObject* PyInit_##name(void);
FOREACH_LIBRARY(DECLARE_LIBRARY_INIT)
#undef DECLARE_LIBRARY_INIT

const char* start = R"PYTHON(
import _ssl # must come before _hashlib otherwise ssl's locks will be set to a Python that might no longer exist...
import sys

@@ -221,10 +149,6 @@ struct InitLockAcquire {
struct __attribute__((visibility("hidden"))) ConcreteInterpreterImpl
    : public torch::deploy::InterpreterImpl {
  ConcreteInterpreterImpl() {
#define APPEND_INIT(name) PyImport_AppendInittab(#name, PyInit_##name);
    FOREACH_LIBRARY(APPEND_INIT)
#undef APPEND_INIT

    BuiltinRegistry::runPreInitialization();

    PyPreConfig preconfig;

82 torch/csrc/deploy/interpreter/register_frozenpython.cpp Normal file
@@ -0,0 +1,82 @@
#include <Python.h>
#include <torch/csrc/deploy/interpreter/builtin_registry.h>

#define FOREACH_LIBRARY(_) \
  _(array) \
  _(_asyncio) \
  _(audioop) \
  _(binascii) \
  _(_bisect) \
  _(_blake2) \
  _(_bz2) \
  _(cmath) \
  _(_codecs_cn) \
  _(_codecs_hk) \
  _(_codecs_iso2022) \
  _(_codecs_jp) \
  _(_codecs_kr) \
  _(_codecs_tw) \
  _(_contextvars) \
  _(_crypt) \
  _(_csv) \
  _(_ctypes) \
  _(_ctypes_test) \
  _(_curses) \
  _(_curses_panel) \
  _(_datetime) \
  _(_decimal) \
  _(_elementtree) \
  _(fcntl) \
  _(grp) \
  _(_hashlib) \
  _(_heapq) \
  _(_json) \
  _(_lsprof) \
  _(_lzma) \
  _(math) \
  _(_md5) \
  _(mmap) \
  _(_multibytecodec) \
  _(_multiprocessing) \
  _(nis) \
  _(_opcode) \
  _(ossaudiodev) \
  _(parser) \
  _(_pickle) \
  _(_posixsubprocess) \
  _(pyexpat) \
  _(_queue) \
  _(_random) \
  _(readline) \
  _(resource) \
  _(select) \
  _(_sha1) \
  _(_sha256) \
  _(_sha3) \
  _(_sha512) \
  _(_socket) \
  _(spwd) \
  _(_ssl) \
  _(_struct) \
  _(syslog) \
  _(termios) \
  _(_testbuffer) \
  _(_testcapi) \
  _(_testimportmultiple) \
  _(_testmultiphase) \
  _(unicodedata) \
  _(xxlimited) \
  _(_xxtestfuzz) \
  _(zlib)

#define DECLARE_LIBRARY_INIT(name) extern "C" PyObject* PyInit_##name(void);
FOREACH_LIBRARY(DECLARE_LIBRARY_INIT)
#undef DECLARE_LIBRARY_INIT

extern "C" struct _frozen _PyImport_FrozenModules[];

#define STD_LIBARY_PARMS(name) , #name, PyInit_##name
REGISTER_TORCH_DEPLOY_BUILTIN(
    frozenpython,
    _PyImport_FrozenModules FOREACH_LIBRARY(STD_LIBARY_PARMS));
#undef STD_LIBARY_PARMS

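The X-macro above keeps one canonical module list driving both the `PyInit_*` declarations and the registry arguments. A rough Python analogue of that single-source-of-truth idea (names abbreviated; purely illustrative):

```python
# One list drives every derived artifact, so adding a module cannot be
# forgotten in one place. The real list has dozens of entries.
STDLIB_MODULES = ["array", "_asyncio", "binascii", "zlib"]

def make_init(name):
    def init():
        return f"initialized {name}"
    return init

REGISTRY = {name: make_init(name) for name in STDLIB_MODULES}
assert REGISTRY["zlib"]() == "initialized zlib"
```
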
29 torch/csrc/jit/mobile/code.h Normal file
@@ -0,0 +1,29 @@
#pragma once

#include <vector>

#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/runtime/instruction.h>

namespace torch {
namespace jit {
namespace mobile {

using Stack = std::vector<c10::IValue>;
using DebugHandle = int64_t;

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct Code {
  std::vector<Instruction> instructions_;
  std::vector<DebugHandle> debug_handles_;
  std::vector<c10::OperatorName> op_names_;
  std::vector<std::function<void(Stack&)>> operators_;
  std::vector<c10::IValue> constants_;
  std::vector<c10::TypePtr> types_;
  size_t register_size_; // Aggregated output size.
};

} // namespace mobile
} // namespace jit
} // namespace torch

53 torch/csrc/jit/mobile/frame.h Normal file
@@ -0,0 +1,53 @@
#pragma once

#include <cstddef>

#include <c10/util/Optional.h>
#include <torch/csrc/jit/mobile/code.h>

namespace torch {
namespace jit {
namespace mobile {

class Frame {
 public:
  explicit Frame(const Code& code) : code_(code) {}
  const Code& getCode() const {
    return code_;
  }

  void step() {
    pc_++;
  }

  void jump(size_t n) {
    pc_ += n;
  }

  size_t getPC() const {
    return pc_;
  }

  const Instruction& getInstruction() const {
    return code_.instructions_.at(pc_);
  }

  c10::optional<int64_t> getDebugHandle() const {
    return getDebugHandle(pc_);
  }

  c10::optional<int64_t> getDebugHandle(size_t pc) const {
    if (pc >= code_.debug_handles_.size()) {
      return {};
    }
    return code_.debug_handles_[pc];
  }

 private:
  const Code& code_;
  size_t pc_{0};
};

} // namespace mobile
} // namespace jit
} // namespace torch

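An illustrative Python mimic (simplified, not the real runtime) of the bookkeeping `Frame` encapsulates: a program counter into the code's instruction list, with a graceful fallback when debug handles were not serialized:

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Code:
    instructions: List[str]
    debug_handles: List[int] = field(default_factory=list)

@dataclass
class Frame:
    code: Code
    pc: int = 0

    def step(self) -> None:
        self.pc += 1

    def jump(self, n: int) -> None:
        self.pc += n

    def debug_handle(self) -> Optional[int]:
        # Mirrors getDebugHandle: no handle when the list is too short.
        if self.pc >= len(self.code.debug_handles):
            return None
        return self.code.debug_handles[self.pc]

frame = Frame(Code(instructions=["OP", "RET"], debug_handles=[42]))
assert frame.debug_handle() == 42
frame.step()
assert frame.debug_handle() is None  # handle list shorter than instructions
```
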
@@ -167,7 +167,7 @@ bool Function::run(Stack& stack) const {
    schema->checkAndNormalizeInputs(
        stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
  }
  InterpreterState interp_state(code_);
  InterpreterState interp_state(*code_);
  return interp_state.run(stack);
}

@@ -181,8 +181,7 @@ const std::shared_ptr<Code> Function::get_code() const {
}

int64_t Function::getExceptionDebugHandle() const {
  size_t pc = getInterpretersExceptionPC();
  return (pc < code_->debug_handles_.size()) ? code_->debug_handles_[pc] : -1;
  return getInterpretersExceptionDebugHandle();
}

} // namespace mobile

@@ -8,23 +8,22 @@
#include <torch/csrc/jit/runtime/vararg_functions.h>

#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/observer.h>

#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/observer.h>

namespace torch {
namespace jit {
char const* toString(OpCode op);
std::ostream& operator<<(std::ostream& out, Instruction inst);
namespace mobile {
InterpreterState::InterpreterState(std::shared_ptr<Code> code)
    : code_(std::move(code)) {
  registers_.resize(code_->register_size_);
InterpreterState::InterpreterState(const Code& code) {
  enterFrame(code);
}

namespace {
static thread_local int64_t exception_pc_{-1};
static thread_local DebugHandle exception_debug_handle_{-1};
void createObject(Stack& stack, const at::ClassTypePtr& type) {
  auto userObj = c10::ivalue::Object::create(
      c10::StrongTypePtr(type->compilation_unit(), type),

@@ -46,21 +45,42 @@ void isinstance(Stack& stack, at::ArrayRef<at::TypePtr> types) {

using namespace at;

int64_t getInterpretersExceptionPC() {
  return exception_pc_;
int64_t getInterpretersExceptionDebugHandle() {
  return exception_debug_handle_;
}

void InterpreterState::enterFrame(const Code& code) {
  frames_.emplace_back(code);
  registers_.resize(registers_.size() + code.register_size_);
}

void InterpreterState::leaveFrame() {
  registers_.resize(
      registers_.size() - frames_.back().getCode().register_size_);
  frames_.pop_back();
}

void InterpreterState::saveExceptionDebugHandle() {
  const auto& frame = frames_.back();
  if (auto handle = frame.getDebugHandle()) {
    exception_debug_handle_ = *handle;
  }
}

bool InterpreterState::run(Stack& stack) {
  size_t pc = 0;
  while (true) {
    try {
      Instruction inst = code_->instructions_.at(pc);

      auto& frame = frames_.back();
      const auto& code = frame.getCode();
      const auto pc = frame.getPC();
      auto inst = frame.getInstruction();
      // If no valid debug handle found then just log pc.
      // This is possible when we did not save debug handles
      DebugHandle debug_handle = pc >= code_->debug_handles_.size()
          ? pc
          : code_->debug_handles_.at(pc);

      DebugHandle debug_handle = pc;
      if (auto handle = frame.getDebugHandle()) {
        debug_handle = *handle;
      }

      // std::cout << "RUNNING " << pc << " "
      //           << code_->instructions_with_handles_[pc].instruction;

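An illustrative sketch (simplified, not the real runtime) of the register windowing `enterFrame`/`leaveFrame` implement: entering a frame grows the shared register file by that frame's size, and leaving shrinks it back:

```python
registers = []
frame_sizes = []

def enter_frame(register_size):
    frame_sizes.append(register_size)
    registers.extend([None] * register_size)

def leave_frame():
    size = frame_sizes.pop()
    del registers[len(registers) - size:]

enter_frame(4)
enter_frame(2)
assert len(registers) == 6  # both windows live
leave_frame()
assert len(registers) == 4  # inner window reclaimed
```
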
@@ -93,63 +113,63 @@ bool InterpreterState::run(Stack& stack) {
        }

        RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
            code_->op_names_[inst.X].name, debug_handle, stack);
        code_->operators_[inst.X](stack);
        ++pc;
            code.op_names_[inst.X].name, debug_handle, stack);
        code.operators_[inst.X](stack);
        frame.step();
      } break;
      case OPN: {
        stack.push_back(inst.N);
        RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
            code_->op_names_[inst.X].name, debug_handle, stack);
        code_->operators_[inst.X](stack);
        ++pc;
            code.op_names_[inst.X].name, debug_handle, stack);
        code.operators_[inst.X](stack);
        frame.step();
      } break;
      case INTERFACE_CALL: {
        torch::jit::Function& method =
            peek(stack, 0, inst.N)
                .toObject()
                ->type()
                ->getMethod(code_->constants_[inst.X].toStringRef());
                ->getMethod(code.constants_[inst.X].toStringRef());
        RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS(
            method.name(), debug_handle, stack);
        method.run(stack);
        ++pc;
        frame.step();
      } break;
      case LOAD:
        stack.emplace_back(reg(inst.X));
        ++pc;
        frame.step();
        break;
      case MOVE:
        stack.emplace_back(std::move(reg(inst.X)));
        ++pc;
        frame.step();
        break;
      case STORE:
        reg(inst.X) = pop(stack);
        ++pc;
        frame.step();
        break;
      case STOREN:
        for (size_t i = inst.N; i > 0; --i) {
          reg(inst.X + i - 1) = pop(stack);
        }
        ++pc;
        frame.step();
        break;
      case DROP:
        pop(stack);
        ++pc;
        frame.step();
        break;
      case DROPR:
        reg(inst.X) = IValue();
        ++pc;
        frame.step();
        break;
      case LOADC:
        stack.emplace_back(code_->constants_[inst.X]);
        ++pc;
        stack.emplace_back(code.constants_[inst.X]);
        frame.step();
        break;
      case GET_ATTR: {
        auto userObj = pop(stack).toObject();
        auto value = userObj->getSlot(inst.X);
        push(stack, std::move(value));
        ++pc;
        frame.step();
      } break;
      case SET_ATTR: {
        auto v = pop(stack);

@@ -163,72 +183,76 @@ bool InterpreterState::run(Stack& stack) {
          userObj->type()->addAttribute(ss.str(), c10::NoneType::get());
        }
        userObj->setSlot(inst.X, std::move(v));
        ++pc;
        frame.step();
      } break;
      case JF:
        pc += (pop(stack).toBool()) ? 1 : inst.X;
        frame.jump(pop(stack).toBool() ? 1 : inst.X);
        break;
      case JMP:
        pc += inst.X;
        frame.jump(inst.X);
        break;
      case LOOP: {
        // stack: iteration_count, max_iter, cond, loop_carried_deps...
        auto frame = stack.end() - (inst.N + 1);
        int64_t trip_count = frame[0].toInt();
        int64_t max_trip_count = frame[1].toInt();
        bool cond = frame[2].toBool();
        auto sframe = stack.end() - (inst.N + 1);
        int64_t trip_count = sframe[0].toInt();
        int64_t max_trip_count = sframe[1].toInt();
        bool cond = sframe[2].toBool();
        if (trip_count < max_trip_count && cond) {
          frame[2] = trip_count;
          frame[0] = trip_count + 1;
          ++pc;
          sframe[2] = trip_count;
          sframe[0] = trip_count + 1;
          frame.step();
        } else {
          size_t n_loop_carried = inst.N - 2;
          for (const auto i : c10::irange(n_loop_carried)) {
            frame[i] = std::move(frame[i + 3]);
            sframe[i] = std::move(sframe[i + 3]);
          }
          drop(stack, 3); // iteration_count, max_iter, cond
          pc += inst.X;
          frame.jump(inst.X);
        }
      } break;
      case RET:
        leaveFrame();
        if (frames_.size() > 0) {
          continue;
        }
        return false;
      case LIST_CONSTRUCT: {
        const auto& type = code_->types_[inst.X]->expectRef<at::ListType>();
        const auto& type = code.types_[inst.X]->expectRef<at::ListType>();
        listConstruct(stack, type, inst.N);
        ++pc;
        frame.step();
      } break;
      case LIST_UNPACK: {
        listUnpack(stack, inst.X);
        ++pc;
        frame.step();
      } break;
      case TUPLE_CONSTRUCT: {
        tupleConstruct(stack, inst.X);
        ++pc;
        frame.step();
      } break;
      case TUPLE_SLICE: {
        tupleSlice(stack, inst.X, inst.X + inst.N);
        ++pc;
        frame.step();
      } break;
      case DICT_CONSTRUCT: {
        const auto& type = code_->types_[inst.X]->expectRef<at::DictType>();
        const auto& type = code.types_[inst.X]->expectRef<at::DictType>();
        dictConstruct(stack, type, inst.N);
        ++pc;
        frame.step();
      } break;
      case NAMED_TUPLE_CONSTRUCT: {
        namedTupleConstruct(
            stack, code_->types_[inst.X]->expect<at::TupleType>(), inst.N);
        ++pc;
            stack, code.types_[inst.X]->expect<at::TupleType>(), inst.N);
        frame.step();
      } break;
      case CREATE_OBJECT: {
        auto type = code_->types_[inst.X]->expect<c10::ClassType>();
        auto type = code.types_[inst.X]->expect<c10::ClassType>();
        createObject(stack, type);
        ++pc;
        frame.step();
      } break;
      case ISINSTANCE: {
        at::ArrayRef<TypePtr> types(
            &(code_->types_[inst.X]), &(code_->types_[inst.X + inst.N]));
            &(code.types_[inst.X]), &(code.types_[inst.X + inst.N]));
        isinstance(stack, types);
        ++pc;
        frame.step();
      } break;
      case WARN: {
        drop(stack, 1);

@@ -240,7 +264,7 @@ bool InterpreterState::run(Stack& stack) {
        const auto& sref = stack.back().toStringRef();
        TORCH_WARN(sref);
        stack.pop_back();
        ++pc;
        frame.step();
      } break;
      default:
        AT_ERROR(toString(inst.op), " is invalid.");

|
|||
}
|
||||
// This exception must be caught first as it derived from c10::Error
|
||||
} catch (c10::BackendRuntimeException& e) {
|
||||
exception_pc_ = pc;
|
||||
saveExceptionDebugHandle();
|
||||
TORCH_RETHROW(e);
|
||||
} catch (c10::Error& error) {
|
||||
// Reason for catching and rethrowing the error is so that we can
|
||||
// set the exception pc that is queried later
|
||||
exception_pc_ = pc;
|
||||
saveExceptionDebugHandle();
|
||||
TORCH_RETHROW(error);
|
||||
} catch (...) {
|
||||
exception_pc_ = pc;
|
||||
saveExceptionDebugHandle();
|
||||
throw;
|
||||
}
|
||||
// for (auto val : stack) {
|
||||
|
|
|
|||
|
|
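A hedged Python analogue of the pattern in that catch block: record where an error happened in thread-local state just before rethrowing, so an outer layer can query it after the fact:

```python
import threading

_state = threading.local()

def save_exception_handle(handle):
    _state.handle = handle

def get_exception_handle():
    return getattr(_state, "handle", -1)

def run(instructions):
    for pc, inst in enumerate(instructions):
        try:
            inst()
        except Exception:
            save_exception_handle(pc)  # sketch: pc stands in for a debug handle
            raise

def boom():
    raise RuntimeError("bad op")

try:
    run([lambda: None, boom])
except RuntimeError:
    assert get_exception_handle() == 1
```
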
@@ -1,35 +1,26 @@
#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/core/operator_name.h>
#include <torch/csrc/jit/runtime/instruction.h>

#include <vector>

#include <torch/csrc/jit/mobile/code.h>
#include <torch/csrc/jit/mobile/frame.h>

namespace torch {
namespace jit {
namespace mobile {
using Stack = std::vector<c10::IValue>;
using DebugHandle = int64_t;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct Code {
  // TODO: Combine instructions and debug handles vector
  // into std::vector<<std::pair<Instruction, DebugHandle>>
  std::vector<Instruction> instructions_;
  std::vector<DebugHandle> debug_handles_;
  std::vector<c10::OperatorName> op_names_;
  std::vector<std::function<void(Stack&)>> operators_;
  std::vector<c10::IValue> constants_;
  std::vector<c10::TypePtr> types_;
  size_t register_size_; // Aggregated output size.
};

struct InterpreterState {
  TORCH_API explicit InterpreterState(std::shared_ptr<Code> code);
  TORCH_API explicit InterpreterState(const Code& code);
  TORCH_API bool run(Stack& stack);

 private:
  std::shared_ptr<Code> code_;
  void enterFrame(const Code&);
  void leaveFrame();
  void saveExceptionDebugHandle();

  c10::IValue& reg(size_t reg);
  std::vector<c10::IValue> registers_;
  std::vector<Frame> frames_;
};

// Interpreter executes instruction in a loop one by one

@@ -39,7 +30,7 @@ struct InterpreterState {
// Note that this is set only when exception occurs.
// since this is a thread local variable and setting it for
// every instruction will add overhead of thread local variable access.
int64_t getInterpretersExceptionPC();
DebugHandle getInterpretersExceptionDebugHandle();
} // namespace mobile
} // namespace jit
} // namespace torch

@@ -78,7 +78,7 @@ void _not(Stack& stack) {
void boolTensor(Stack& stack) {
  at::Tensor a;
  pop(stack, a);
  push(stack, a.is_nonzero());
  push(stack, at::native::is_nonzero(a));
}

void toList(Stack& stack) {

@@ -17,6 +17,16 @@ namespace jit {

namespace {

std::vector<std::string> splitName(const std::string& name) {
  std::vector<std::string> result;
  std::string sub_name;
  std::istringstream name_stream(name);
  while (std::getline(name_stream, sub_name, '.')) {
    result.push_back(std::move(sub_name));
  }
  return result;
}

class AttributePropagator {
 public:
  AttributePropagator(

@@ -27,27 +37,27 @@ class AttributePropagator {
      : module_(module),
        freezeInterfaces_(freezeInterfaces),
        preserveParameters_(preserveParameters) {
    // Currently only top level attributes and functions can be preserved
    // explicitly.
    auto checkName = [this](std::string& name) {
      if (module_.hasattr(name)) {
        auto attr = module_.attr(name);
      const auto resolved_name = resolveName(name);

        // Freezing client wants to presever this submodule. When cleaning
        // the frozen module, make sure it will be preserved entirely.
        if (attr.isModule()) {
          preservedSubModule_.insert(attr.toModule()._ivalue());
      if (resolved_name) {
        const auto& parent_module = resolved_name->first;
        const auto& attr_name = resolved_name->second;
        if (parent_module.hasattr(attr_name)) {
          auto value = parent_module.attr(attr_name);
          // Freezing client wants to presever this submodule. When cleaning
          // the frozen module, make sure it will be preserved entirely.
          if (value.isModule()) {
            preservedSubModule_.insert(value.toModule()._ivalue());
          }
          insertMutableAttr(attr_name, value, parent_module._ivalue());
        } else {
          auto fn = parent_module.get_method(attr_name);
          preservedMethods_.insert(&fn.function());
        }
        insertMutableAttr(name, attr, module_._ivalue());
        return true;
      }

      for (auto& fn : module_.type()->methods()) {
        if (fn->name() == name) {
          preservedMethods_.insert(fn);
          return true;
        }
      }
      return false;
    };

@@ -119,6 +129,57 @@ class AttributePropagator {
  }

 private:
  using ResolvedName = std::pair<Module, std::string>;

  // Try to resolve qualified names (submodule1.submodule2.foo). If
  // the qualified name exists in the root module, return the unqualified
  // attribute/function name and the parent module. Else, return nullopt.
  // Examples:
  // submodule1.submodule2.foo -> {submodule2, "foo"}
  // submodule1.non_existent_module.foo -> nullopt
  c10::optional<ResolvedName> resolveName(const std::string& name) {
    auto sub_names = splitName(name);
    if (sub_names.empty()) {
      return c10::nullopt;
    }
    auto& attr_name = sub_names.back();
    auto cur_module = module_;
    std::vector<ResolvedName> attr_infos;
    attr_infos.reserve(sub_names.size() - 1);

    for (size_t i = 0; i < sub_names.size() - 1; ++i) {
      bool found = false;
      const auto& sub_name = sub_names[i];
      for (const auto& child_module : cur_module.named_children()) {
        if (child_module.name == sub_name) {
          attr_infos.emplace_back(cur_module._ivalue(), child_module.name);
          cur_module = child_module.value;
          found = true;
          break;
        }
      }
      if (!found) {
        return c10::nullopt;
      }
    }

    if (cur_module.hasattr(attr_name) || cur_module.find_method(attr_name)) {
      // We don't want to mark these modules as mutable yet; that could
      // interfere with the inlining procedure. Instead, we'll record
      // the fact that the user wants to preserve them. They will be
      // processed during clean-up preparation (recordReferenceAttrs)
      for (auto& attr_info : attr_infos) {
        const auto& parent_module = attr_info.first;
        auto& sub_name = attr_info.second;
        userPreservedAttrs_[parent_module._ivalue()].insert(
            std::move(sub_name));
      }
      return std::make_pair(std::move(cur_module), std::move(attr_name));
    }

    return c10::nullopt;
  }

  // findConstantAttr function locates the sub Module where attributes are
  // defined. The algorithm chases getAttr chains to locate the submodules.
  // For example:

|
|||
}
|
||||
}
|
||||
}
|
||||
// We have to process the attributes that the user wants to preserve
|
||||
// separately since it's possible that the user-preserved module is
|
||||
// never referenced in the graph.
|
||||
for (const auto& attr_info : userPreservedAttrs_) {
|
||||
const auto& parent_module = attr_info.first;
|
||||
for (const auto& attr_name : attr_info.second) {
|
||||
const auto value = parent_module->getAttr(attr_name);
|
||||
insertMutableAttr(attr_name, value, parent_module);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This function recursively iterates over submodules to identify
|
||||
|
|
@@ -710,7 +781,7 @@ class AttributePropagator {
      }
    }
    for (auto& fn : type->methods()) {
      if (preservedMethods_.count(fn) && *type == *module_.type()) {
      if (preservedMethods_.count(fn)) {
        continue;
      }
      funcsToRemove.push_back(fn);

@@ -774,6 +845,11 @@ class AttributePropagator {
      c10::intrusive_ptr<at::ivalue::Object>>
      object_memo_;

  // Contains names of attributes that the user wants to preserve with
  // their owning modules.
  std::unordered_map<ModulePtr, std::unordered_set<std::string>>
      userPreservedAttrs_;

}; // class AttributePropagator

void checkModuleDoesNotReturnSelf(const Module& module) {

@@ -91,6 +91,7 @@
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/csrc/jit/runtime/static/init.h>

@@ -520,6 +521,22 @@ void initJITBindings(PyObject* module) {
          },
          py::doc(
              "Interpret a JIT graph with given inputs without running any optimization passes on it"))
      .def(
          "_jit_trace_graph",
          [](std::shared_ptr<Graph>& graph, const py::tuple& inputs) {
            Stack stack;
            stack.reserve(inputs.size()); // captures?
            for (auto& obj : inputs) {
              stack.push_back(toTypeInferredIValue(obj));
            }
            auto g_inputs = graph->inputs();
            for (const auto i : c10::irange(inputs.size())) {
              if (stack[i].isTensor()) {
                g_inputs[i]->setType(stack[i].type());
              }
            }
            return TraceGraph(graph, stack);
          })
      .def("_jit_pass_remove_expands", RemoveExpands)
      .def("_jit_pass_erase_number_types", EraseNumberTypes)
      .def("_jit_pass_inline_fork_wait", InlineForkWait)

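A hedged sketch of exercising the new binding from Python. This is a private API; the call shape below is read off the lambda above and the exact behavior is an assumption:

```python
import torch

@torch.jit.script
def f(x):
    return x + 1

# Trace the scripted graph with concrete inputs (private binding).
traced_graph = torch._C._jit_trace_graph(f.graph, (torch.ones(2),))
print(traced_graph)
```
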
Some files were not shown because too many files have changed in this diff.