mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export][cond] support merging constant ints as unbacked symint (#152742)
@pianpwk points out that this will be helpful to address several data-dependent issues in huggingface [models](https://github.com/huggingface/diffusers/blob/e23705e557/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py#L332) with the following pattern:
```python
idx = 0 if u0 else 1
return x[idx]
```
We could preserve the conditional with a cond.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152742
Approved by: https://github.com/zou3519
This commit is contained in:
parent
025c5cc048
commit
fc859077a0
|
|
@ -3134,14 +3134,14 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre
|
||||||
)
|
)
|
||||||
|
|
||||||
pred = torch.tensor(True)
|
pred = torch.tensor(True)
|
||||||
for pytree_in in [(1,), ("string",), (1.0,)]:
|
for pytree_in in [("string",), (1.0,)]:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
r"Expect operands to be a tuple of possibly nested dict/list/tuple",
|
r"Expect operands to be a tuple of possibly nested dict/list/tuple",
|
||||||
):
|
):
|
||||||
fn(pred, pytree_in)
|
fn(pred, pytree_in)
|
||||||
|
|
||||||
for pytree_in in [(1,), ("string",), (1.0,)]:
|
for pytree_in in [("string",), (1.0,)]:
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
torch._dynamo.exc.UncapturedHigherOrderOpError,
|
||||||
r"Cond doesn't work unless it is captured completely with torch.compile",
|
r"Cond doesn't work unless it is captured completely with torch.compile",
|
||||||
|
|
|
||||||
|
|
@ -1355,6 +1355,98 @@ graph():
|
||||||
M()(torch.randn(7))
|
M()(torch.randn(7))
|
||||||
torch.export.export(M(), (torch.randn(7),), strict=strict)
|
torch.export.export(M(), (torch.randn(7),), strict=strict)
|
||||||
|
|
||||||
|
def test_cond_branches_return_constant_int(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 1, tuple())
|
||||||
|
return x[idx]
|
||||||
|
|
||||||
|
args = (torch.randn(3, 3),)
|
||||||
|
m = M()
|
||||||
|
ep = export(M(), args)
|
||||||
|
if self._testMethodName == "test_cond_branches_return_constant_int":
|
||||||
|
self.assertExpectedInline(
|
||||||
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
||||||
|
"""\
|
||||||
|
class GraphModule(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x: "f32[3, 3]";
|
||||||
|
|
||||||
|
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||||
|
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
|
||||||
|
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
|
||||||
|
|
||||||
|
true_graph_0 = self.true_graph_0
|
||||||
|
false_graph_0 = self.false_graph_0
|
||||||
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
|
||||||
|
|
||||||
|
getitem_1: "Sym(u0)" = cond[0]; cond = None
|
||||||
|
|
||||||
|
ge_1: "Sym(u0 >= 0)" = getitem_1 >= 0
|
||||||
|
_assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
|
||||||
|
le_1: "Sym(u0 <= 1)" = getitem_1 <= 1
|
||||||
|
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 1 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
|
||||||
|
|
||||||
|
select: "f32[3]" = torch.ops.aten.select.int(x, 0, getitem_1); x = getitem_1 = None
|
||||||
|
return pytree.tree_unflatten((select,), self._out_spec)
|
||||||
|
|
||||||
|
class true_graph_0(torch.nn.Module):
|
||||||
|
def forward(self):
|
||||||
|
return (0,)
|
||||||
|
|
||||||
|
class false_graph_0(torch.nn.Module):
|
||||||
|
def forward(self):
|
||||||
|
return (1,)
|
||||||
|
""", # noqa: B950
|
||||||
|
)
|
||||||
|
self.assertEqual(m(*args), ep.module()(*args))
|
||||||
|
|
||||||
|
def test_cond_branches_return_same_int(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
idx = torch.cond(x.sum() > 3, lambda: 0, lambda: 0, tuple())
|
||||||
|
return x[idx]
|
||||||
|
|
||||||
|
args = (torch.randn(3, 3),)
|
||||||
|
m = M()
|
||||||
|
ep = export(M(), args)
|
||||||
|
# Ideally, we could remove the cond at the front end directly
|
||||||
|
# since it's not used anyway. But we can only do this early
|
||||||
|
# optimization if all the outputs are the same constants, which
|
||||||
|
# would complicate the output check, so we just keep it in the graph.
|
||||||
|
# Let downstream passes DCE it.
|
||||||
|
if self._testMethodName == "test_cond_branches_return_same_int":
|
||||||
|
self.assertExpectedInline(
|
||||||
|
normalize_gm(ep.module().print_readable(print_output=False)),
|
||||||
|
"""\
|
||||||
|
class GraphModule(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
x: "f32[3, 3]";
|
||||||
|
|
||||||
|
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||||
|
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
|
||||||
|
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
|
||||||
|
|
||||||
|
true_graph_0 = self.true_graph_0
|
||||||
|
false_graph_0 = self.false_graph_0
|
||||||
|
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, ()); gt = true_graph_0 = false_graph_0 = None
|
||||||
|
getitem = cond[0]; cond = getitem = None
|
||||||
|
|
||||||
|
select: "f32[3]" = torch.ops.aten.select.int(x, 0, 0); x = None
|
||||||
|
return pytree.tree_unflatten((select,), self._out_spec)
|
||||||
|
|
||||||
|
class true_graph_0(torch.nn.Module):
|
||||||
|
def forward(self):
|
||||||
|
return (0,)
|
||||||
|
|
||||||
|
class false_graph_0(torch.nn.Module):
|
||||||
|
def forward(self):
|
||||||
|
return (0,)
|
||||||
|
""", # noqa: B950
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(m(*args), ep.module()(*args))
|
||||||
|
|
||||||
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
||||||
def test_cond_contains_unbacked_no_escape(self):
|
def test_cond_contains_unbacked_no_escape(self):
|
||||||
class M(torch.nn.Module):
|
class M(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -8328,10 +8328,10 @@ class GraphModule(torch.nn.Module):
|
||||||
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
|
_ = self._check_export_ret_graph_str(model, args, dynamic_shapes)
|
||||||
|
|
||||||
@skipIfTorchDynamo(
|
@skipIfTorchDynamo(
|
||||||
"Skip because _merge_tensors is not intended for dynamo to compile"
|
"Skip because _merge_output is not intended for dynamo to compile"
|
||||||
)
|
)
|
||||||
def test_merge_tensors(self):
|
def test_merge_output(self):
|
||||||
from torch._higher_order_ops.cond import _merge_tensors
|
from torch._higher_order_ops.cond import _merge_output
|
||||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||||
|
|
||||||
|
|
@ -8376,7 +8376,7 @@ class GraphModule(torch.nn.Module):
|
||||||
with fake_mode:
|
with fake_mode:
|
||||||
t1 = torch.empty_strided(size1, stride1)
|
t1 = torch.empty_strided(size1, stride1)
|
||||||
t2 = torch.empty_strided(size2, stride2)
|
t2 = torch.empty_strided(size2, stride2)
|
||||||
out = _merge_tensors(t1, t2, fake_mode)
|
out = _merge_output(t1, t2, fake_mode)
|
||||||
self.assertEqual(str(tuple(out.size())), merged_size)
|
self.assertEqual(str(tuple(out.size())), merged_size)
|
||||||
self.assertEqual(str(tuple(out.stride())), merged_stride)
|
self.assertEqual(str(tuple(out.stride())), merged_stride)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1067,10 +1067,17 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
||||||
supports_aliasing=self.supports_aliasing,
|
supports_aliasing=self.supports_aliasing,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not only_consist_of(ret_val, (TensorVariable,)):
|
if not only_consist_of(ret_val, (TensorVariable, ConstantVariable)):
|
||||||
unimplemented(
|
unimplemented(
|
||||||
"Expected branches to return a possibly nested list/tuple/dict of tensors but it consists of non tensors.",
|
"Expected branches to return a possibly nested pytree of tensors "
|
||||||
|
"or constant ints but it consists of others.",
|
||||||
)
|
)
|
||||||
|
for ret in ret_val.unpack_var_sequence(tx):
|
||||||
|
if isinstance(ret, ConstantVariable) and ret.python_type() is not int:
|
||||||
|
unimplemented(
|
||||||
|
"Expected branches to return a possibly nested pytree of tensors "
|
||||||
|
f"or constant ints but it consists of others {ret.python_type()}.",
|
||||||
|
)
|
||||||
return ret_val, ret_treespec, ret_graph, ret_lifted_freevars
|
return ret_val, ret_treespec, ret_graph, ret_lifted_freevars
|
||||||
|
|
||||||
(true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
|
(true_r, true_treespec, true_graph, true_lifted_freevars) = speculate_branch(
|
||||||
|
|
|
||||||
|
|
@ -1641,11 +1641,13 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||||
self.module,
|
self.module,
|
||||||
self.serialized_name_to_node,
|
self.serialized_name_to_node,
|
||||||
self.serialized_name_to_meta,
|
self.serialized_name_to_meta,
|
||||||
|
self.unbacked_symbols
|
||||||
)
|
)
|
||||||
self.graph = torch.fx.Graph()
|
self.graph = torch.fx.Graph()
|
||||||
self.module = torch.nn.Module()
|
self.module = torch.nn.Module()
|
||||||
self.serialized_name_to_node = {}
|
self.serialized_name_to_node = {}
|
||||||
self.serialized_name_to_meta = {}
|
self.serialized_name_to_meta = {}
|
||||||
|
self.unbacked_symbols: set[sympy.Symbol] = set()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -1654,6 +1656,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||||
self.module,
|
self.module,
|
||||||
self.serialized_name_to_node,
|
self.serialized_name_to_node,
|
||||||
self.serialized_name_to_meta,
|
self.serialized_name_to_meta,
|
||||||
|
self.unbacked_symbols
|
||||||
) = saved
|
) = saved
|
||||||
|
|
||||||
def deserialize_extension_operator(self, serialized_target: str):
|
def deserialize_extension_operator(self, serialized_target: str):
|
||||||
|
|
@ -2184,7 +2187,7 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||||
self.symbol_name_to_range = {}
|
self.symbol_name_to_range = {}
|
||||||
# we also need to bump unbacked sym[float,int] counters in the
|
# we also need to bump unbacked sym[float,int] counters in the
|
||||||
# shape env to accommodate unbacked symbols in the exported program
|
# shape env to accommodate unbacked symbols in the exported program
|
||||||
self.unbacked_symbols: set[sympy.Symbol] = set()
|
self.unbacked_symbols = set()
|
||||||
count_unbacked_symfloat, count_unbacked_symint = -1, -1
|
count_unbacked_symfloat, count_unbacked_symint = -1, -1
|
||||||
unbacked_symfloat_prefix, unbacked_symint_prefix = (
|
unbacked_symfloat_prefix, unbacked_symint_prefix = (
|
||||||
prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT]
|
prefix_str[t] for t in [SymT.UNBACKED_FLOAT, SymT.UNBACKED_INT]
|
||||||
|
|
@ -2422,27 +2425,34 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||||
# Check single value return
|
# Check single value return
|
||||||
if len(serialized_node.outputs) == 0:
|
if len(serialized_node.outputs) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
len(serialized_node.outputs) == 1
|
len(serialized_node.outputs) == 1
|
||||||
and serialized_node.outputs[0].type == "as_tensor"
|
and "torch.ops.higher_order" in serialized_node.target
|
||||||
|
and not getattr(serialized_node, "is_hop_single_tensor_return", True)
|
||||||
):
|
):
|
||||||
# If it is a HOP node and it returns a tuple containing a single element
|
def _deserialize_hop_with_single_return(serialized_node, fx_node):
|
||||||
# we manually insert a getitem node to ensure the graph is consistent
|
|
||||||
# For BC, getattr() will return True if `is_hop_single_tensor_return` doesn't exist
|
|
||||||
# as prior to adding this field, it is guaranteed to have a single tensor return
|
|
||||||
# when the serialized_node has length=1 outputs and of type `as_tensor`.
|
|
||||||
if (
|
|
||||||
"torch.ops.higher_order" in serialized_node.target
|
|
||||||
and not getattr(serialized_node, "is_hop_single_tensor_return", True)
|
|
||||||
):
|
|
||||||
meta_val: list[Any] = []
|
meta_val: list[Any] = []
|
||||||
arg = serialized_node.outputs[0].as_tensor
|
arg = None
|
||||||
|
if serialized_node.outputs[0].type == "as_tensor":
|
||||||
|
arg = serialized_node.outputs[0].as_tensor
|
||||||
|
elif isinstance(serialized_node.outputs[0].value, (SymIntArgument, SymBoolArgument, SymFloatArgument)):
|
||||||
|
arg = serialized_node.outputs[0].value
|
||||||
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
|
deserialized_metadata = self.deserialize_metadata(serialized_node.metadata)
|
||||||
|
assert arg is not None
|
||||||
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
|
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
|
||||||
fx_node.meta["val"] = tuple(meta_val)
|
fx_node.meta["val"] = tuple(meta_val)
|
||||||
self.serialized_name_to_node[fx_node.name] = fx_node
|
self.serialized_name_to_node[fx_node.name] = fx_node
|
||||||
return
|
return
|
||||||
|
|
||||||
|
return _deserialize_hop_with_single_return(serialized_node, fx_node)
|
||||||
|
|
||||||
|
|
||||||
|
if (
|
||||||
|
len(serialized_node.outputs) == 1
|
||||||
|
and serialized_node.outputs[0].type == "as_tensor"
|
||||||
|
):
|
||||||
|
|
||||||
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
|
self.sync_fx_node(serialized_node.outputs[0].as_tensor.name, fx_node)
|
||||||
return
|
return
|
||||||
elif len(serialized_node.outputs) == 1 and isinstance(
|
elif len(serialized_node.outputs) == 1 and isinstance(
|
||||||
|
|
|
||||||
|
|
@ -42,6 +42,7 @@ __all__ = [
|
||||||
"while_loop",
|
"while_loop",
|
||||||
"invoke_subgraph",
|
"invoke_subgraph",
|
||||||
"scan",
|
"scan",
|
||||||
|
"map",
|
||||||
"flex_attention",
|
"flex_attention",
|
||||||
"flex_attention_backward",
|
"flex_attention_backward",
|
||||||
"hints_wrapper",
|
"hints_wrapper",
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,9 @@ def cond(
|
||||||
false_fn (Callable): A callable function (a -> b) that is within the
|
false_fn (Callable): A callable function (a -> b) that is within the
|
||||||
scope that is being traced. The true branch and false branch must
|
scope that is being traced. The true branch and false branch must
|
||||||
have consistent input and outputs, meaning the inputs have to be
|
have consistent input and outputs, meaning the inputs have to be
|
||||||
the same, and the outputs have to be the same type and shape.
|
the same, and the outputs have to be the same type and shape. Int
|
||||||
|
output is also allowed. We'll make the output dynamic by turning it
|
||||||
|
into a symint.
|
||||||
|
|
||||||
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the
|
operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the
|
||||||
true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to ().
|
true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to ().
|
||||||
|
|
@ -429,7 +431,7 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands):
|
||||||
|
|
||||||
merged_outs = []
|
merged_outs = []
|
||||||
for true_out, false_out in zip(flat_true_outs, flat_false_outs):
|
for true_out, false_out in zip(flat_true_outs, flat_false_outs):
|
||||||
merged_outs.append(_merge_tensors(true_out, false_out, mode))
|
merged_outs.append(_merge_output(true_out, false_out, mode))
|
||||||
return pytree.tree_unflatten(merged_outs, true_out_spec)
|
return pytree.tree_unflatten(merged_outs, true_out_spec)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -451,8 +453,10 @@ def check_tensor_meta_match(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _merge_tensors(
|
def _merge_output(
|
||||||
a: Optional[torch.Tensor], b: Optional[torch.Tensor], mode: FakeTensorMode
|
a: Optional[Union[torch.Tensor, int]],
|
||||||
|
b: Optional[Union[torch.Tensor, int]],
|
||||||
|
mode: FakeTensorMode,
|
||||||
):
|
):
|
||||||
from torch.fx.experimental.symbolic_shapes import (
|
from torch.fx.experimental.symbolic_shapes import (
|
||||||
has_free_unbacked_symbols,
|
has_free_unbacked_symbols,
|
||||||
|
|
@ -463,6 +467,28 @@ def _merge_tensors(
|
||||||
assert a is None and b is None, (a, b)
|
assert a is None and b is None, (a, b)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def min_max(s0, s1):
|
||||||
|
def _bound(s0, lower_bound: bool):
|
||||||
|
if isinstance(s0, int):
|
||||||
|
return s0
|
||||||
|
r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
|
||||||
|
s0.node.expr,
|
||||||
|
torch.utils._sympy.value_ranges.ValueRanges.unknown(),
|
||||||
|
)
|
||||||
|
return r.lower if lower_bound else r.upper
|
||||||
|
|
||||||
|
return min(_bound(s0, True), _bound(s1, True)), max(
|
||||||
|
_bound(s0, False), _bound(s1, False)
|
||||||
|
)
|
||||||
|
|
||||||
|
if type(a) is int and type(b) is int:
|
||||||
|
if a == b:
|
||||||
|
return a
|
||||||
|
assert mode.shape_env is not None
|
||||||
|
merged_out = mode.shape_env.create_unbacked_symint()
|
||||||
|
mode.shape_env.constrain_symbol_range(merged_out.node.expr, *min_max(a, b))
|
||||||
|
return merged_out
|
||||||
|
|
||||||
assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))
|
assert type(a) is FakeTensor and type(b) is FakeTensor, (a, type(a), b, type(b))
|
||||||
|
|
||||||
# Note: we don't check size, stride because
|
# Note: we don't check size, stride because
|
||||||
|
|
@ -517,21 +543,6 @@ def _merge_tensors(
|
||||||
):
|
):
|
||||||
merged_size.append(s0)
|
merged_size.append(s0)
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def min_max(s0, s1):
|
|
||||||
def _bound(s0, lower_bound: bool):
|
|
||||||
if isinstance(s0, int):
|
|
||||||
return s0
|
|
||||||
r = mode.shape_env.var_to_range.get( # type: ignore[union-attr]
|
|
||||||
s0.node.expr,
|
|
||||||
torch.utils._sympy.value_ranges.ValueRanges.unknown(),
|
|
||||||
)
|
|
||||||
return r.lower if lower_bound else r.upper
|
|
||||||
|
|
||||||
return min(_bound(s0, True), _bound(s1, True)), max(
|
|
||||||
_bound(s0, False), _bound(s1, False)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mode.shape_env is not None
|
assert mode.shape_env is not None
|
||||||
new_size = mode.shape_env.create_unbacked_symint()
|
new_size = mode.shape_env.create_unbacked_symint()
|
||||||
mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
|
mode.shape_env.constrain_symbol_range(new_size.node.expr, *min_max(s0, s1))
|
||||||
|
|
|
||||||
|
|
@ -792,6 +792,9 @@ def check_input_alias_and_mutation_return_ouputs(
|
||||||
# has a persistent fake mode but fake tensors can be created
|
# has a persistent fake mode but fake tensors can be created
|
||||||
# outside of the tracing context (e.g. in testing).
|
# outside of the tracing context (e.g. in testing).
|
||||||
# Instead, we just look at fake_args fake tensor mode
|
# Instead, we just look at fake_args fake tensor mode
|
||||||
|
if len(fake_args) == 0:
|
||||||
|
return torch.fx.experimental.symbolic_shapes.ShapeEnv()
|
||||||
|
|
||||||
prev_fake_mode = None
|
prev_fake_mode = None
|
||||||
for arg in fake_args:
|
for arg in fake_args:
|
||||||
if isinstance(arg, torch.Tensor):
|
if isinstance(arg, torch.Tensor):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user