There has been a previous discussion (#484) about supporting other forms of loss calculation during LoRA training, in order to allow the kind of standard instruction fine-tuning that other systems, like axolotl, support, such as setting `train_on_inputs` to false.
Currently, `iterate_batches` doesn't preserve the distinction between input and output tokens that the completions dataset format has, so the returned token sequence lengths cover the entire sequence. This makes it straightforward for the dataset API to rely on `tokenizer.apply_chat_template` to serialize the text according to the model's prompt format.
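For concreteness, here is roughly what that serialization step looks like today (the record field names and the example model are assumptions for illustration; the actual dataset wrapper may differ):

```python
from transformers import AutoTokenizer

# Any chat model with a chat template works here; this one is just an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Assumed completions-style record with separate prompt and completion fields.
record = {"prompt": "What is the capital of France?", "completion": "Paris."}
messages = [
    {"role": "user", "content": record["prompt"]},
    {"role": "assistant", "content": record["completion"]},
]

# apply_chat_template merges prompt and completion into one token sequence,
# so by the time iterate_batches sees it, the prompt/completion boundary is gone.
tokens = tokenizer.apply_chat_template(messages, tokenize=True)
print(len(tokens))  # only the total length is available downstream
```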
However, if `iterate_batches` were modified to keep that distinction and return the lengths of the inputs in addition to

`batch[:, :-1], batch[:, 1:], mx.array(lengths)`

then that distinction could be preserved for a later loss function to use in masking out the inputs, a very common need in language model instruction tuning.
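A minimal sketch of what such a modified batch iterator might yield (the extra `input_lengths` array, its position in the tuple, and the `(prompt_tokens, completion_tokens)` dataset element shape are assumptions, not the existing API):

```python
import numpy as np
import mlx.core as mx

def iterate_batches(dataset, batch_size):
    """Yield (inputs, targets, input_lengths, lengths) for each batch.

    Each dataset element is assumed to be a (prompt_tokens, completion_tokens)
    pair, so the prompt/completion boundary survives batching.
    """
    for i in range(0, len(dataset), batch_size):
        chunk = dataset[i : i + batch_size]
        input_lengths = [len(p) for p, _ in chunk]
        lengths = [len(p) + len(c) for p, c in chunk]
        max_len = max(lengths)

        # Pad every sequence in the batch to the longest one.
        padded = np.zeros((len(chunk), max_len), dtype=np.int32)
        for j, (p, c) in enumerate(chunk):
            padded[j, : lengths[j]] = p + c
        batch = mx.array(padded)

        # Same shifted (inputs, targets) pair as before, plus the prompt
        # lengths so a downstream loss can mask out the input portion.
        yield batch[:, :-1], batch[:, 1:], mx.array(input_lengths), mx.array(lengths)
```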
It would be ideal to continue relying on `apply_chat_template` for serializing according to a (semi-)standard chat template, but I think this simple change would open the door to some low-hanging-fruit custom loss calculation functions (such as `train_on_input` behavior), as sketched below.
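For example, a loss function consuming those lengths could mask the prompt tokens out of the cross-entropy. This is a sketch only; the function name and signature are assumptions, not the existing trainer API:

```python
import mlx.core as mx
import mlx.nn as nn

def masked_loss(model, inputs, targets, input_lengths, lengths):
    """Cross-entropy over completion tokens only (train_on_input = False)."""
    logits = model(inputs)

    # Position index of every target token in the (shifted) sequence.
    positions = mx.arange(targets.shape[1])[None, :]

    # Keep tokens that come after the prompt but before the padding; the
    # "- 1" accounts for inputs/targets being shifted by one position.
    mask = mx.logical_and(
        positions >= (input_lengths[:, None] - 1),
        positions < (lengths[:, None] - 1),
    )

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    return ce.sum() / ntoks, ntoks
```

This mirrors the shape of the default masked cross-entropy used for padding, just with the lower bound moved from zero to the end of the prompt.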