Add approximate state persistence (#73)
Summary:
Test Plan:
***
More verbose multiprocess logging, fix get_state_and_recycle
Summary:
Test Plan:
Files changed:
- bytelatent/args.py +8 -2
- bytelatent/data/iterators/multiprocess_iterator.py +150 -54
- bytelatent/train.py +41 -21
bytelatent/args.py CHANGED

@@ -13,7 +13,10 @@ from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import StatefulIterator
 from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
 from bytelatent.data.iterators.looping_iterator import LoopingIterator
-from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    PersistType,
+)
 from bytelatent.data.iterators.packing_iterator import (
     PackingArgs,
     PackingIterator,
@@ -130,6 +133,7 @@ class DataloaderArgs(BaseModel):
     add_bos: bool = True
     add_eos: bool = True
     load_async: bool = True
+    async_persist_type: PersistType = PersistType.EXACT
     prefetch_size: int = 64
     preprocess_dir: str | None = None
     dataset_files: list[str] | None = None
@@ -215,7 +219,9 @@ class DataloaderArgs(BaseModel):
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
             mp_iterator = MultiprocessIterator(
-                packing_iterator, n_batches_to_prefetch=self.prefetch_size
+                packing_iterator,
+                n_batches_to_prefetch=self.prefetch_size,
+                persist_type=self.async_persist_type,
             )
             return mp_iterator
         else:
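Usage note: the new `async_persist_type` field defaults to `PersistType.EXACT`, so existing configs keep the old drain-and-rebuild checkpoint behavior. A minimal sketch of selecting a mode (only the fields shown come from this diff; it assumes the remaining `DataloaderArgs` fields have usable defaults):

    from bytelatent.args import DataloaderArgs
    from bytelatent.data.iterators.multiprocess_iterator import PersistType

    # Default: exact persistence, which drains and rebuilds the worker on save.
    exact = DataloaderArgs(load_async=True)

    # Opt in to approximate persistence: snapshot the worker without stopping it.
    approx = DataloaderArgs(
        load_async=True,
        async_persist_type=PersistType.APPROXIMATE,
    )

Because `PersistType` subclasses `str`, a YAML or CLI config should also be able to set `async_persist_type: approximate` and let pydantic coerce the string into the enum.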
bytelatent/data/iterators/multiprocess_iterator.py CHANGED

@@ -2,6 +2,7 @@
 import json
 import logging
 import multiprocessing as mp
+from enum import Enum
 from multiprocessing.synchronize import Event as EventClass
 from queue import Empty, Full
 
@@ -19,11 +20,17 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 logger = logging.getLogger()
 
 
+class PersistType(str, Enum):
+    EXACT = "exact"
+    APPROXIMATE = "approximate"
+
+
 class MultiprocessIteratorState(PydanticIteratorState):
     model_config = ConfigDict(extra="forbid")
     base_iterator_state: PackingIteratorState
     n_batches_to_prefetch: int
     serialized_prefetch_buffer: str
+    persist_type: PersistType
 
     def build(self):
         base_iterator = self.base_iterator_state.build()
@@ -33,14 +40,19 @@ class MultiprocessIteratorState(PydanticIteratorState):
             base_iterator,
             n_batches_to_prefetch=self.n_batches_to_prefetch,
             prefetch_buffer=prefetch_buffer,
+            persist_type=self.persist_type,
         )
 
 
 def start_work_from_state(
     batch_queue: mp.Queue,
     state_queue: mp.Queue,
+    approximate_state_queue: mp.Queue,
     stop_event: EventClass,
     state_dumped_event: EventClass,
+    trigger_approximate_send_state_event: EventClass,
+    sent_approximate_state_event: EventClass,
+    received_approximate_state_event: EventClass,
     state: IteratorState,
 ):
     logging.info("Worker thread: Starting base_iterator work")
@@ -49,6 +61,25 @@ def start_work_from_state(
     for item in iterator:
         while not stop_event.is_set():
             try:
+                if trigger_approximate_send_state_event.is_set():
+                    logger.info("WT: trigger_approximate_send ack")
+                    # Since this can be triggered again (but only after the state is received on mp),
+                    # we should cleanup as soon as possible.
+                    trigger_approximate_send_state_event.clear()
+                    logging.info("WT: Computing approximate state")
+                    approximate_state = stateful_iterator.get_state()
+                    # At this state, there should always be exactly 1 slot.
+                    # Blocking here would be a bug.
+                    logger.info("WT: Attempting to send approximate state")
+                    approximate_state_queue.put(
+                        approximate_state, block=True, timeout=None
+                    )
+                    sent_approximate_state_event.set()
+                    logger.info("WT: Approximate state sent")
+                    # Same here, clear events as we no longer need them.
+                    received_approximate_state_event.wait()
+                    received_approximate_state_event.clear()
+                    logger.info("WT: State received by MT, resuming batch iteration")
                 # Attempt to put on queue or timeout to try again (maybe main thread is busy)
                 batch_queue.put(item, timeout=0.1)
                 # On success, stop trying
@@ -58,10 +89,10 @@
         if stop_event.is_set():
             # Signal the end of output, this ensures that even if the queue takes a while to
             # buffer, that the main thread receives everything (and tosses this fake batch)
-            logging.debug(
+            logging.info(
                 "Worker thread: Stop event detected, outputting is_final=True batch"
             )
-            logging.debug("Worker thread: batch_queue full=%s", batch_queue.full())
+            logging.info("Worker thread: batch_queue full=%s", batch_queue.full())
             batch_queue.put(
                 Batch(
                     x=np.zeros((1, 1)),
@@ -72,23 +103,26 @@
                     ngram_ids=None,
                 )
            )
-            logging.debug(
+            logging.info(
                 "Worker thread: is_final=True batch put in queue, breaking from loop."
             )
             break
 
    try:
-        logging.debug("Worker thread: outputting state")
+        logging.info("Worker thread: outputting state")
         state_queue.put(stateful_iterator.get_state(), timeout=1)
-        logging.debug("Worker thread: state dump complete")
+        logging.info("Worker thread: state dump complete")
         state_dumped_event.set()
-        logging.debug("Worker thread: set state_dump_event")
+        logging.info("Worker thread: set state_dump_event")
     except Full:
         raise ValueError(
             "Attempted to dump state into the state queue, but it was full"
         )
 
 
+FETCH_STATE_TIMEOUT = 120
+
+
 class MultiprocessIterator(StatefulIterator):
     """
     Design sketch of the multiprocess iterator:
@@ -124,18 +158,24 @@ class MultiprocessIterator(StatefulIterator):
         base_iterator: StatefulIterator,
         *,
         n_batches_to_prefetch: int,
-        prefetch_buffer: list | None = None
+        prefetch_buffer: list | None = None,
+        persist_type: PersistType = PersistType.EXACT,
     ):
         self.base_iterator = base_iterator
         self.n_batches_to_prefetch = n_batches_to_prefetch
+        self.persist_type = persist_type
         if prefetch_buffer is None:
             prefetch_buffer = []
         self.prefetch_buffer = prefetch_buffer
         self.batch_queue = None
         self.state_queue = None
+        self.approximate_state_queue = None
         self.producer = None
         self.stop_iterating_event = None
         self.state_dumped_event = None
+        self.trigger_approximate_send_state_event = None
+        self.sent_approximate_state_event = None
+        self.received_approximate_state_event = None
         self.force_shutdown = False
 
     def shutdown(self):
@@ -144,6 +184,92 @@ class MultiprocessIterator(StatefulIterator):
             self.producer.kill()
         self.force_shutdown = True
 
+    def _get_state_exact(self):
+        logging.info("Main thread: Sending stop iteration event")
+        self.stop_iterating_event.set()
+        logging.info(
+            "Main thread: Emptying the batch_queue until batch.is_final=True is found."
+        )
+        self.prefetch_buffer = []
+        final_batch_received = False
+        while True:
+            try:
+                batch = self.batch_queue.get(timeout=1)
+                if batch.is_final:
+                    logging.info(
+                        "Main thread: is_final=True batch found, stopping fetch from batch_queue"
+                    )
+                    final_batch_received = True
+                    break
+                self.prefetch_buffer.append(batch)
+            except Empty:
+                logging.warning("Main thread: batch_queue is abnormally empty")
+        assert final_batch_received
+
+        logging.info("Main thread: Waiting for state_dumped event")
+        self.state_dumped_event.wait()
+
+        try:
+            logging.info(
+                "Main thread: state_dumped_event received, waiting for state from queue"
+            )
+            base_iterator_state = self.state_queue.get(timeout=FETCH_STATE_TIMEOUT)
+            logging.info("Main thread: received state from queue")
+            assert isinstance(base_iterator_state, IteratorState)
+        except Empty:
+            raise ValueError(
+                "Attempted to get the state, but it was unexpectantly missing"
+            )
+
+        self.base_iterator = base_iterator_state.build()
+        self.producer.close()
+        self.producer = None
+        self.batch_queue = None
+        self.state_queue = None
+        self.approximate_state_queue = None
+        self.stop_iterating_event = None
+        self.state_dumped_event = None
+        self.trigger_approximate_send_state_event = None
+        self.sent_approximate_state_event = None
+        self.received_approximate_state_event = None
+
+        return MultiprocessIteratorState(
+            base_iterator_state=self.base_iterator.get_state(),
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
+    def _get_state_approximate(self):
+        logging.info("MT: Sending approximate get_state request")
+        self.trigger_approximate_send_state_event.set()
+        logging.info("MT: Waiting for sent_approximate_state_event")
+        self.sent_approximate_state_event.wait()
+        logging.info("MT: sent_approximate_state_event ack")
+        try:
+            logging.info("MT: waiting for approximate state in queue")
+            base_iterator_state = self.approximate_state_queue.get(
+                timeout=FETCH_STATE_TIMEOUT
+            )
+            logging.info("MT: approximate state received")
+            assert isinstance(base_iterator_state, IteratorState)
+            assert self.approximate_state_queue.empty()
+        except Empty:
+            raise ValueError(
+                "Attempted to get approximate state, but queue was erroniously empty."
+            )
+        self.received_approximate_state_event.set()
+        return MultiprocessIteratorState(
+            base_iterator_state=base_iterator_state,
+            n_batches_to_prefetch=self.n_batches_to_prefetch,
+            serialized_prefetch_buffer=json.dumps(
+                [b.to_python_dict() for b in self.prefetch_buffer]
+            ),
+            persist_type=self.persist_type,
+        )
+
     def get_state(self) -> MultiprocessIteratorState:
         """
         This is slightly unusual in effectively destroying the current iterator, its necessary
@@ -162,55 +288,15 @@
                 base_iterator_state=self.base_iterator.get_state(),
                 n_batches_to_prefetch=self.n_batches_to_prefetch,
                 serialized_prefetch_buffer=serialized_prefetch_buffer,
+                persist_type=self.persist_type,
             )
         else:
-            logging.debug("Main thread: Sending stop iteration event")
-            self.stop_iterating_event.set()
-            logging.debug(
-                "Main thread: Emptying the batch_queue until batch.is_final=True is found."
-            )
-            self.prefetch_buffer = []
-            final_batch_received = False
-            while True:
-                try:
-                    batch = self.batch_queue.get(timeout=1)
-                    if batch.is_final:
-                        logging.debug(
-                            "Main thread: is_final=True batch found, stopping fetch from batch_queue"
-                        )
-                        final_batch_received = True
-                        break
-                    self.prefetch_buffer.append(batch)
-                except Empty:
-                    logging.warning("Main thread: batch_queue is abnormally empty")
-            assert final_batch_received
-
-            logging.debug("Main thread: Waiting for state_dumped event")
-            self.state_dumped_event.wait()
-
-            try:
-                base_iterator_state = self.state_queue.get(timeout=1)
-                assert isinstance(base_iterator_state, IteratorState)
-            except Empty:
-                raise ValueError(
-                    "Attempted to get the state, but it was unexpectantly missing"
-                )
-
-            self.base_iterator = base_iterator_state.build()
-            self.producer.close()
-            self.producer = None
-            self.batch_queue = None
-            self.state_queue = None
-            self.stop_iterating_event = None
-            self.state_dumped_event = None
-
-            return MultiprocessIteratorState(
-                base_iterator_state=self.base_iterator.get_state(),
-                n_batches_to_prefetch=self.n_batches_to_prefetch,
-                serialized_prefetch_buffer=json.dumps(
-                    [b.to_python_dict() for b in self.prefetch_buffer]
-                ),
-            )
+            if self.persist_type == PersistType.EXACT:
+                return self._get_state_exact()
+            elif self.persist_type == PersistType.APPROXIMATE:
+                return self._get_state_approximate()
+            else:
+                raise ValueError("invalid persist_type")
 
     def create_iter(self):
         if self.force_shutdown:
@@ -236,8 +322,14 @@
             # We should only ever one state, which is output at the detection of a stop event
             self.state_queue = ctx.Manager().Queue(maxsize=1)
 
+            # Similarly, there should only ever be one state in flight due to event signals
+            self.approximate_state_queue = ctx.Manager().Queue(maxsize=1)
+
             self.stop_iterating_event = ctx.Event()
             self.state_dumped_event = ctx.Event()
+            self.trigger_approximate_send_state_event = ctx.Event()
+            self.sent_approximate_state_event = ctx.Event()
+            self.received_approximate_state_event = ctx.Event()
 
             self.producer = mp.Process(
                 name="blt_data_loader",
@@ -245,8 +337,12 @@
                 args=(
                     self.batch_queue,
                     self.state_queue,
+                    self.approximate_state_queue,
                     self.stop_iterating_event,
                     self.state_dumped_event,
+                    self.trigger_approximate_send_state_event,
+                    self.sent_approximate_state_event,
+                    self.received_approximate_state_event,
                     self.base_iterator.get_state(),
                 ),
             )
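Design note: the approximate path avoids the stop/drain/rebuild cycle of `_get_state_exact` by handshaking with the live worker through three events. The main process sets `trigger_approximate_send_state_event`; the worker notices it between batch-queue puts, computes its state, puts it on the one-slot `approximate_state_queue`, and sets `sent_approximate_state_event`; it then blocks on `received_approximate_state_event` until the main process has drained the queue, which guarantees at most one state is ever in flight. A standalone sketch of that handshake using only the standard library (the integer counter stands in for the real iterator state; everything else is illustrative, not bytelatent code):

    import multiprocessing as mp
    from queue import Full

    def worker(batch_queue, approx_state_queue, trigger, sent, received):
        state = 0
        while True:
            # Check for a snapshot request between attempted batch puts,
            # mirroring the retry loop in start_work_from_state.
            if trigger.is_set():
                trigger.clear()
                approx_state_queue.put(state)  # exactly one slot is free here
                sent.set()
                received.wait()  # pause until the main process confirms receipt
                received.clear()
            try:
                batch_queue.put(state, timeout=0.1)
                state += 1
            except Full:
                continue  # main process is busy; retry and re-check the trigger

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        batch_queue = ctx.Queue(maxsize=4)
        approx_state_queue = ctx.Manager().Queue(maxsize=1)
        trigger, sent, received = ctx.Event(), ctx.Event(), ctx.Event()
        producer = ctx.Process(
            target=worker,
            args=(batch_queue, approx_state_queue, trigger, sent, received),
        )
        producer.start()
        for _ in range(3):
            batch_queue.get()  # consume a few "batches"
        trigger.set()  # request an approximate snapshot
        sent.wait()
        snapshot = approx_state_queue.get(timeout=120)
        received.set()  # the worker resumes producing
        print("approximate snapshot:", snapshot)
        producer.terminate()
        producer.join()

The snapshot is "approximate" because batches still in flight in the multiprocessing queue are not folded back into the saved state: on restore, the stream resumes from the worker's position rather than the trainer's, so in-flight batches may be skipped. The exact path pays for precision by draining the queue into `prefetch_buffer` and tearing the worker down.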
bytelatent/train.py CHANGED

@@ -31,6 +31,7 @@ from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
+    PersistType,
 )
 from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
@@ -712,9 +713,15 @@ def train(args: TrainArgs):
         if every_n_steps(
             train_state, args.checkpoint.dump.every, acc_step=0
         ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
-            train_state.data_loader_state, data_loader, batch_iterator = (
-                get_state_and_refresh(data_loader)
-            )
+            if (
+                args.data.load_async
+                and args.data.async_persist_type == PersistType.EXACT
+            ):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
+            else:
+                train_state.data_loader_state = data_loader.get_state()
             saved = checkpoint.save(
                 model,
                 optimizer,
@@ -756,9 +763,16 @@
 
     if preemption_flag["flag"]:
         if not saved:
-            train_state.data_loader_state, data_loader, batch_iterator = (
-                get_state_and_refresh(data_loader)
-            )
+            if (
+                args.data.load_async
+                and args.data.async_persist_type == PersistType.EXACT
+            ):
+                train_state.data_loader_state, data_loader, batch_iterator = (
+                    get_state_and_refresh(data_loader)
+                )
+            else:
+                train_state.data_loader_state = data_loader.get_state()
+
             checkpoint.save(
                 model,
                 optimizer,
@@ -769,21 +783,27 @@
         requeue_slurm_job()
         sys.exit(0)
 
-    if not saved:
-        train_state.data_loader_state, data_loader, batch_iterator = (
-            get_state_and_refresh(data_loader)
-        )
-        checkpoint.save(
-            model,
-            optimizer,
-            train_state,
-            args,
-            device_mesh=world_mesh,
-        )
-    if isinstance(data_loader, MultiprocessIterator):
-        logger.info("Closing MP iterator before exiting")
-        data_loader.shutdown()
-    gc.collect()
+    if not saved:
+        if (
+            args.data.load_async
+            and args.data.async_persist_type == PersistType.EXACT
+        ):
+            train_state.data_loader_state, data_loader, batch_iterator = (
+                get_state_and_refresh(data_loader)
+            )
+        else:
+            train_state.data_loader_state = data_loader.get_state()
+        checkpoint.save(
+            model,
+            optimizer,
+            train_state,
+            args,
+            device_mesh=world_mesh,
+        )
+    if isinstance(data_loader, MultiprocessIterator):
+        logger.info("Closing MP iterator before exiting")
+        data_loader.shutdown()
+    gc.collect()
 
 
 def main():
|