[NFC] Fix some minor typos. (#145599)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145599
Approved by: https://github.com/Skylion007

parent: 6cda572c98
commit: a989a0b13a
@@ -1078,7 +1078,7 @@ class PythonWrapperCodegen(CodeGen):
             arg.get_dtype() if isinstance(arg, IRNode) else type(arg)
             for arg in raw_args
         ]
-        # Because generate_kernel_call can be overriden by a subclass, explictly call
+        # Because generate_kernel_call can be overriden by a subclass, explicitly call
         # PythonWrapperCodegen.generate_kernel_call here
         PythonWrapperCodegen.generate_kernel_call(
             self,
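The comment this hunk touches relies on a standard Python idiom: calling a method through the class, as in PythonWrapperCodegen.generate_kernel_call(self, ...), skips dynamic dispatch, so any subclass override is bypassed. A minimal sketch with hypothetical Base/Child classes (not from the PyTorch source):

class Base:
    def generate(self):
        return "Base.generate"

class Child(Base):
    def generate(self):  # override that the explicit call below bypasses
        return "Child.generate"

obj = Child()
print(obj.generate())      # dynamic dispatch -> "Child.generate"
print(Base.generate(obj))  # explicit class call -> "Base.generate"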
@@ -2376,7 +2376,7 @@ class PythonWrapperCodegen(CodeGen):
 
     def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
         # All inputs of hops must be explicitly passed in.
-        # Free tensors and basic symbols should have been explictily lifted as inputs in dynamo.
+        # Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
         assert len(outer_inputs) == len(
             subgraph.graph.graph_input_names
         ), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
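The invariant documented here is that every free tensor or symbol a higher-order-op subgraph uses must already have been lifted to an explicit input by dynamo, so the outer argument list and the subgraph's input names match one-to-one. A hedged sketch of that check, with illustrative names not taken from the PyTorch source:

def bind_subgraph_inputs(outer_inputs, graph_input_names):
    # the caller must supply exactly one value per lifted subgraph input
    assert len(outer_inputs) == len(graph_input_names), (
        f"graph_input_names:{graph_input_names}, outer_inputs: {outer_inputs}"
    )
    return dict(zip(graph_input_names, outer_inputs))

print(bind_subgraph_inputs(["buf0", "s0"], ["arg0_1", "arg1_1"]))
# -> {'arg0_1': 'buf0', 'arg1_1': 's0'}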
@@ -691,7 +691,7 @@ inline IValue toTypeInferredIValue(py::handle input) {
   if (auto mod = as_module(object)) {
     // if obj is already a ScriptModule, just return its ivalue
     auto ptr = mod.value()._ivalue();
-    // explict copy semantics for strong ownership of the resource.
+    // explicit copy semantics for strong ownership of the resource.
     return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
         ptr.release());
   }
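The comment describes explicit copy semantics: the returned intrusive_ptr strongly owns the object rather than aliasing the original handle's lifetime. Python has no intrusive_ptr; as a loose analogy only, an independent deep copy gives the holder ownership that the original handle cannot invalidate (all names below are illustrative):

import copy

class Handle:
    def __init__(self, payload):
        self.payload = payload

def to_owned(handle):
    # take an independent, strongly owned copy of the underlying data
    return Handle(copy.deepcopy(handle.payload))

h1 = Handle({"weights": [1, 2, 3]})
h2 = to_owned(h1)
h1.payload["weights"].clear()  # mutating/releasing h1 ...
print(h2.payload["weights"])   # ... leaves h2's copy intact: [1, 2, 3]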
@@ -1811,7 +1811,7 @@ def _shutdown_backend(pg):
         except RuntimeError:
             pass
     if is_nccl_available() and isinstance(backend, ProcessGroupNCCL):
-        # explictly call shutdown to ensure that NCCL resources are released
+        # explicitly call shutdown to ensure that NCCL resources are released
         backend._shutdown()
 
 
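The pattern here is deterministic resource release: external resources (NCCL communicators in this case) are freed by an explicit shutdown call rather than left to garbage collection, whose timing is unspecified. A minimal sketch with a hypothetical FakeBackend standing in for ProcessGroupNCCL:

class FakeBackend:
    def __init__(self):
        self.live = True

    def _shutdown(self):
        # free communicators/streams deterministically
        self.live = False

backend = FakeBackend()
try:
    pass  # ... collective work ...
finally:
    backend._shutdown()  # released even if the work above raised
assert not backend.live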
@@ -344,7 +344,7 @@ def stage_backward(
     # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors,
     # and to itself (extract_tensors_with_grads) since it makes a recursive call
     # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad
-    # fix -> explictly pass in the ref to the fn, so there is no gc cycle anymore
+    # fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore
     extract_tensors_with_grads(
         stage_output, output_grads, extract_tensors_with_grads
     )
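The comments in this hunk record a real GC hazard: a recursive inner function that refers to itself by name closes over its own cell, forming a reference cycle that keeps everything else it captures (here, activation tensors) alive until the cycle collector runs. Passing the function to itself as an argument removes the self-reference. A self-contained sketch with illustrative names:

def cyclic():
    big = [0] * 1_000_000  # stands in for activation tensors

    def walk(n):           # `walk` names itself -> function <-> cell cycle
        return walk(n - 1) if n else len(big)

    return walk(3)         # `big` lingers until gc.collect()

def acyclic():
    big = [0] * 1_000_000

    def walk(n, self_fn):  # no self-reference in the closure
        return self_fn(n - 1, self_fn) if n else len(big)

    return walk(3, walk)   # `big` is freed promptly on return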
@@ -142,7 +142,7 @@ def convert_arg_type_and_name( # type: ignore[return]
             new_callsite_exprs,
         )
     elif isinstance(typ, ListType):
-        # Need to explictly pass the list as pointer + length
+        # Need to explicitly pass the list as pointer + length
        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
 
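The comment concerns the usual C-ABI convention for sequences: C has no list type, so a list crosses the boundary as a data pointer plus an element count. A small illustration of that convention using ctypes (names and element types here are assumptions for the sketch, not the codegen's actual output):

import ctypes

def as_ptr_and_len(values):
    arr = (ctypes.c_long * len(values))(*values)  # contiguous C array
    ptr = ctypes.cast(arr, ctypes.POINTER(ctypes.c_long))
    return ptr, len(values), arr  # keep `arr` alive while `ptr` is used

ptr, n, _keepalive = as_ptr_and_len([3, 1, 4])
print([ptr[i] for i in range(n)])  # -> [3, 1, 4]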