[BE][Easy] remove torch.torch.xxx usages (#127800)

NB: `torch` is exposed as an attribute of itself in `torch/__init__.py` (its `import torch.foo` statements bind the name `torch` inside the module's own namespace), so chains of any depth, such as `torch.torch.torch.xxx`, resolve to the same objects.
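A quick illustration of why these chains resolve (not part of the patch, just standard Python module semantics demonstrated with asserts):

import torch

# torch/__init__.py executes statements like ``import torch.cuda``, which
# bind the name ``torch`` inside the torch module's own namespace.
# The module therefore contains itself as an attribute:
assert torch.torch is torch

# ...so arbitrarily deep chains resolve to the same objects:
assert torch.torch.torch.float64 is torch.float64
assert torch.torch.cuda is torch.cuda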

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127800
Approved by: https://github.com/peterbell10, https://github.com/kit1980, https://github.com/malfet
Author: Xuehai Pan
Date: 2024-06-05 21:53:49 +00:00
Committed by: PyTorch MergeBot
Parent: 4123323eff
Commit: a7c596870d
5 changed files with 9 additions and 11 deletions
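For context, a minimal sketch of how remaining `torch.torch.` references could be located mechanically; `find_redundant_torch` is a hypothetical helper, not part of this commit:

import pathlib
import re

# Matches one or more redundant ``torch.`` segments, e.g. ``torch.torch.float64``
# and ``torch.torch.torch.xxx``.
_REDUNDANT_TORCH = re.compile(r"\btorch\.(?:torch\.)+\w+")

def find_redundant_torch(root: str) -> list[tuple[str, int, str]]:
    """Return (path, line number, match) for each redundant reference under root."""
    hits = []
    for path in pathlib.Path(root).rglob("*.py"):
        text = path.read_text(encoding="utf-8")
        for lineno, line in enumerate(text.splitlines(), start=1):
            for match in _REDUNDANT_TORCH.finditer(line):
                hits.append((str(path), lineno, match.group()))
    return hits

Running such a scan over the repository before this patch would flag call sites like the ones changed in the hunks below.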


@@ -497,7 +497,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
         a_float32 = torch.rand((8, 8), device="cuda")
         b_float32 = torch.rand((8, 8), device="cuda")
-        with torch.cuda.amp.autocast(dtype=torch.torch.float64):
+        with torch.cuda.amp.autocast(dtype=torch.float64):
             c_float64 = torch.mm(a_float32, b_float32)
         return c_float64
@@ -796,7 +796,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(exported.dtype, real_dtype)
         self.assertEqual(exported.device.index, 0)
-        self.assertEqual(exported.dtype, torch.torch.float16)
+        self.assertEqual(exported.dtype, torch.float16)

     @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
     def test_autocast_arguments_binding(self):


@@ -380,10 +380,10 @@ class TestCuda(TestCase):
         def check_workspace_size(inp):
             torch._C._cuda_clearCublasWorkspaces()
-            start = torch.torch.cuda.memory_stats()["active_bytes.all.allocated"]
+            start = torch.cuda.memory_stats()["active_bytes.all.allocated"]
             with torch.no_grad():
                 torch.matmul(inp, inp)
-            finish = torch.torch.cuda.memory_stats()["active_bytes.all.allocated"]
+            finish = torch.cuda.memory_stats()["active_bytes.all.allocated"]
             return finish - start

         # check default


@@ -178,9 +178,7 @@ def preserve_global_state(fn):
         finally:
             cleanup.close()
             torch._C._set_grad_enabled(prior_grad_mode)
-            torch.torch.autograd.grad_mode._enter_inference_mode(
-                prior_inference_mode
-            )
+            torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
             torch.use_deterministic_algorithms(
                 prior_deterministic, warn_only=prior_warn_only
             )


@@ -2421,7 +2421,7 @@ def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
                 ops.index_expr(
                     ModularIndexing(idx[dim] - start, 1, step), torch.int64
                 ),
-                ops.constant(0, torch.torch.int64),
+                ops.constant(0, torch.int64),
             )
         )
         assert mask