mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61234 * **#61234 [WIP] Adding demux and mux DataPipe API examples** Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D29588836 Pulled By: VitalyFedyunin fbshipit-source-id: 523d12ea6be7507d706b4c6d8827ec1ac4ccabc3
1131 lines
26 KiB
Plaintext
1131 lines
26 KiB
Plaintext
{
|
|
"metadata": {
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.10"
|
|
},
|
|
"orig_nbformat": 2,
|
|
"kernelspec": {
|
|
"name": "python3610jvsc74a57bd0eb5e09632d6ea1cbf3eb9da7e37b7cf581db5ed13074b21cc44e159dc62acdab",
|
|
"display_name": "Python 3.6.10 64-bit ('dataloader': conda)"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2,
|
|
"cells": [
|
|
{
|
|
"source": [
|
|
"## Standard flow control and data processing DataPipes"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torch.utils.data import IterDataPipe"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Example IterDataPipe\n",
|
|
"class ExampleIterPipe(IterDataPipe):\n",
|
|
" def __init__(self, range = 20):\n",
|
|
" self.range = range\n",
|
|
" def __iter__(self):\n",
|
|
" for i in range(self.range):\n",
|
|
" yield i"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Batch\n",
|
|
"\n",
|
|
"Function: `batch`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" - `batch_size: int` desired batch size\n",
|
|
" - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
|
|
" - `drop_last: bool = False`\n",
|
|
"\n",
|
|
"Example:\n",
|
|
"\n",
|
|
"Classic batching produces partial batches by default\n"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n[9]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"To drop incomplete batches add `drop_last` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3, drop_last = True)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"Sequential calling of `batch` produces nested batches"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[[0, 1, 2], [3, 4, 5]]\n[[6, 7, 8], [9, 10, 11]]\n[[12, 13, 14], [15, 16, 17]]\n[[18, 19, 20], [21, 22, 23]]\n[[24, 25, 26], [27, 28, 29]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(30).batch(3).batch(2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"It is possible to unbatch source data before applying the new batching rule using `unbatch_level` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(30).batch(3).batch(2).batch(10, unbatch_level=-1)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Unbatch\n",
|
|
"\n",
|
|
"Function: `unbatch`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" `unbatch_level:int = 1`\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"9\n0\n1\n2\n6\n7\n8\n3\n4\n5\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).shuffle().unbatch()\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"By default unbatching is applied only to the first layer; to unbatch deeper, use the `unbatch_level` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1]\n[2, 3]\n[4, 5]\n[6, 7]\n[8, 9]\n[10, 11]\n[12, 13]\n[14, 15]\n[16, 17]\n[18, 19]\n[20, 21]\n[22, 23]\n[24, 25]\n[26, 27]\n[28, 29]\n[30, 31]\n[32, 33]\n[34, 35]\n[36, 37]\n[38, 39]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"Setting `unbatch_level` to `-1` will unbatch to the lowest level"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"0\n1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n11\n12\n13\n14\n15\n16\n17\n18\n19\n20\n21\n22\n23\n24\n25\n26\n27\n28\n29\n30\n31\n32\n33\n34\n35\n36\n37\n38\n39\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(40).batch(2).batch(4).batch(3).unbatch(unbatch_level = -1)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Map\n",
|
|
"\n",
|
|
"Function: `map`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" - `nesting_level: int = 0`\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"0\n2\n4\n6\n8\n10\n12\n14\n16\n18\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).map(lambda x: x * 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"`map` by default applies the function to every mini-batch as a whole\n"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1, 2, 0, 1, 2]\n[3, 4, 5, 3, 4, 5]\n[6, 7, 8, 6, 7, 8]\n[9, 9]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).map(lambda x: x * 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"To apply the function to individual items of the mini-batch, use the `nesting_level` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[[0, 2, 4], [6, 8, 10]]\n[[12, 14, 16], [18]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).batch(2).map(lambda x: x * 2, nesting_level = 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"Setting `nesting_level` to `-1` will apply `map` function to the lowest level possible"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[[[0, 2, 4], [6, 8, 10]], [[12, 14, 16], [18]]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).batch(2).batch(2).map(lambda x: x * 2, nesting_level = -1)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Filter\n",
|
|
"\n",
|
|
"Function: `filter`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" - `nesting_level: int = 0`\n",
|
|
" - `drop_empty_batches = True` whether empty batches are dropped or not.\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"0\n2\n4\n6\n8\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).filter(lambda x: x % 2 == 0)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"Classic `filter` by default applies the filter function to every mini-batch as a whole\n"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1, 2]\n[3, 4, 5]\n[6, 7, 8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10)\n",
|
|
"dp = dp.batch(3).filter(lambda x: len(x) > 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"You can apply the filter function to individual elements by setting the `nesting_level` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[5]\n[6, 7, 8]\n[9]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10)\n",
|
|
"dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = 1)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"If a mini-batch ends up with zero elements after filtering, the default behaviour is to drop it from the output. You can override this behaviour using the `drop_empty_batches` argument.\n"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[]\n[5]\n[6, 7, 8]\n[9]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10)\n",
|
|
"dp = dp.batch(3).filter(lambda x: x > 4, nesting_level = -1, drop_empty_batches = False)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[[[0, 1, 2], [3]], [[], [10, 11]]]\n[[[12, 13, 14], [15, 16, 17]], [[18, 19]]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(20)\n",
|
|
"dp = dp.batch(3).batch(2).batch(2).filter(lambda x: x < 4 or x > 9 , nesting_level = -1, drop_empty_batches = False)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Shuffle\n",
|
|
"\n",
|
|
"Function: `shuffle`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
|
|
" - `buffer_size: int = 10000`\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"2\n9\n4\n0\n3\n7\n8\n5\n6\n1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).shuffle()\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"`shuffle` operates on input mini-batches the same way as on individual items"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 1, 2]\n[3, 4, 5]\n[9]\n[6, 7, 8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).shuffle()\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"To shuffle elements across batches, use the `shuffle(unbatch_level)` followed by `batch` pattern"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[2, 1, 0]\n[7, 9, 6]\n[3, 5, 4]\n[8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).shuffle(unbatch_level = -1).batch(3)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Collate\n",
|
|
"\n",
|
|
"Function: `collate`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"tensor([0, 1, 2])\ntensor([3, 4, 5])\ntensor([6, 7, 8])\ntensor([9])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).collate()\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## GroupBy\n",
|
|
"\n",
|
|
"Function: `groupby`\n",
|
|
"\n",
|
|
"Usage: `dp.groupby(lambda x: x[0])`\n",
|
|
"\n",
|
|
"Description: Batching items by combining items with same key into same batch \n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" - `group_key_fn`\n",
|
|
" - `group_size` - yield the resulting group as soon as `group_size` elements have accumulated\n",
|
|
" - `guaranteed_group_size:int = None`\n",
|
|
" - `unbatch_level:int = 0` if specified calls `unbatch(unbatch_level=unbatch_level)` on source datapipe before batching (see `unbatch`)\n",
|
|
"\n",
|
|
"#### Attention\n",
|
|
"As the data stream can be arbitrarily large, grouping is done on a best-effort basis and there is no guarantee that the same key will never appear in different groups. You can think of it as a local groupby, where the locality is a single DataPipe process/thread."
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 3, 6, 9]\n[1, 4, 7]\n[5, 2, 8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).shuffle().groupby(lambda x: x % 3)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"By default the group key function is applied to the entire input (mini-batch)"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[[0, 1, 2], [3, 4, 5], [6, 7, 8]]\n[[9]]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).groupby(lambda x: len(x))\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"It is possible to unnest items from the mini-batches using `unbatch_level` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 3, 6, 9]\n[1, 4, 7]\n[2, 5, 8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10).batch(3).groupby(lambda x: x % 3, unbatch_level = 1)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"When the internal buffer (defined by `buffer_size`) overfills, `groupby` will yield the biggest group available"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[9, 3]\n[13, 4, 7]\n[2, 11, 14, 5]\n[0, 6, 12]\n[1, 10]\n[8]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, buffer_size = 5)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"`groupby` will produce `group_size`-sized batches on an as-fast-as-possible basis"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[6, 3, 12]\n[1, 16, 7]\n[2, 5, 8]\n[14, 11, 17]\n[15, 9, 0]\n[10, 4, 13]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(18).shuffle().groupby(lambda x: x % 3, group_size = 3)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"Remaining groups must be at least `guaranteed_group_size` big. "
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[11, 2, 5]\n[1, 4, 10]\n[0, 9, 6]\n[14, 8]\n[13, 7]\n[12, 3]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, group_size = 3, guaranteed_group_size = 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"Without a defined `group_size`, the function will try to accumulate at least `guaranteed_group_size` elements before yielding the resulting group"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[3, 6, 9, 12, 0]\n[14, 2, 8, 11, 5]\n[7, 4, 1, 13, 10]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"This behaviour becomes noticeable when the data is bigger than the buffer and some groups get evicted before gathering all potential items"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[0, 3]\n[1, 4, 7]\n[2, 5, 8]\n[6, 9, 12]\n[10, 13]\n[11, 14]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(15).groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"With randomness involved you might end up with incomplete groups (so the next example is expected to fail in most cases)"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[14, 5, 11]\n[1, 7, 4, 10]\n[0, 12, 6]\n[8, 2]\n[9, 3]\n"
|
|
]
|
|
},
|
|
{
|
|
"output_type": "error",
|
|
"ename": "Exception",
|
|
"evalue": "('Failed to group items', '[13]')",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mException\u001b[0m Traceback (most recent call last)",
|
|
"\u001b[0;32m<ipython-input-31-673b9dd7fb43>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mdp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mExampleIterPipe\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m15\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuffle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroupby\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mguaranteed_group_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffer_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m6\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;32m~/dataset/pytorch/torch/utils/data/datapipes/iter/grouping.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 276\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop_remaining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 277\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Failed to group items'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbiggest_key\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 278\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mbiggest_size\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mguaranteed_group_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|
"\u001b[0;31mException\u001b[0m: ('Failed to group items', '[13]')"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"To avoid this error and drop incomplete groups, use `drop_remaining` argument"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"[5, 2, 14]\n[4, 7, 13, 1, 10]\n[12, 6, 3, 9]\n[8, 11]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(15).shuffle().groupby(lambda x: x % 3, guaranteed_group_size = 2, buffer_size = 6, drop_remaining = True)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Zip\n",
|
|
"\n",
|
|
"Function: `zip`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"(0, 3)\n(1, 0)\n(2, 4)\n(3, 2)\n(4, 1)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"_dp = ExampleIterPipe(5).shuffle()\n",
|
|
"dp = ExampleIterPipe(5).zip(_dp)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Fork\n",
|
|
"\n",
|
|
"Function: `fork`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" \n",
|
|
"Example:"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"0\n1\n0\n1\n0\n1\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(2)\n",
|
|
"dp1, dp2, dp3 = dp.fork(3)\n",
|
|
"for i in dp1 + dp2 + dp3:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Demultiplexer\n",
|
|
"\n",
|
|
"Function: `demux`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" \n",
|
|
"Example:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1\n",
|
|
"4\n",
|
|
"7\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(10)\n",
|
|
"dp1, dp2, dp3 = dp.demux(3, lambda x: x % 3)\n",
|
|
"for i in dp2:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Multiplexer\n",
|
|
"\n",
|
|
"Function: `mux`\n",
|
|
"\n",
|
|
"Description: \n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
"\n",
|
|
"Arguments:\n",
|
|
" \n",
|
|
"Example:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0\n",
|
|
"0\n",
|
|
"0\n",
|
|
"1\n",
|
|
"10\n",
|
|
"100\n",
|
|
"2\n",
|
|
"20\n",
|
|
"200\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp1 = ExampleIterPipe(3)\n",
|
|
"dp2 = ExampleIterPipe(3).map(lambda x: x * 10)\n",
|
|
"dp3 = ExampleIterPipe(3).map(lambda x: x * 100)\n",
|
|
"\n",
|
|
"dp = dp1.mux(dp2, dp3)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
},
|
|
{
|
|
"source": [
|
|
"## Concat\n",
|
|
"\n",
|
|
"Function: `concat`\n",
|
|
"\n",
|
|
"Description: Returns a DataPipe with elements from the first datapipe followed by elements from the second datapipes\n",
|
|
"\n",
|
|
"Alternatives:\n",
|
|
" \n",
|
|
" `dp = dp.concat(dp2, dp3)`\n",
|
|
" `dp = dp.concat(*datapipes_list)`\n",
|
|
"\n",
|
|
"Example:\n"
|
|
],
|
|
"cell_type": "markdown",
|
|
"metadata": {}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"output_type": "stream",
|
|
"name": "stdout",
|
|
"text": [
|
|
"0\n1\n2\n3\n0\n1\n2\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"dp = ExampleIterPipe(4)\n",
|
|
"dp2 = ExampleIterPipe(3)\n",
|
|
"dp = dp.concat(dp2)\n",
|
|
"for i in dp:\n",
|
|
" print(i)"
|
|
]
|
|
}
|
|
]
|
|
} |