|
@@ -20,8 +20,6 @@ import logging
|
|
|
import os
|
|
|
import sys
|
|
|
import pdb
|
|
|
-import yaml
|
|
|
-import dagshub
|
|
|
import subprocess
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
@@ -39,7 +37,6 @@ from transformers import (
|
|
|
EvalPrediction,
|
|
|
HfArgumentParser,
|
|
|
Trainer,
|
|
|
- TrainerCallback,
|
|
|
TrainingArguments,
|
|
|
set_seed,
|
|
|
)
|
|
@@ -96,28 +93,16 @@ class DataTrainingArguments:
|
|
|
)
|
|
|
|
|
|
|
|
|
-class DAGsHubCallback(TrainerCallback):
|
|
|
- def __init__(self, logger):
|
|
|
- super(TrainerCallback, self).__init__()
|
|
|
- self.logger = logger
|
|
|
-
|
|
|
- def on_log(self, args, state, control, logs, model=None, **kwargs):
|
|
|
- if state.is_world_process_zero:
|
|
|
- self.logger.log_metrics({k:v for k,v in logs.items() if isinstance(v, (int, float))}, step_num=state.global_step)
|
|
|
-
|
|
|
-
|
|
|
def main():
|
|
|
# See all possible arguments in src/transformers/training_args.py
|
|
|
# or by passing the --help flag to this script.
|
|
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
|
|
|
|
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
|
|
- if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
|
|
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
|
|
# If we pass only one argument to the script and it's the path to a json file,
|
|
|
# let's parse it to get our arguments.
|
|
|
- with open(os.path.abspath(sys.argv[1]), "r") as params_file:
|
|
|
- params = yaml.load(params_file)
|
|
|
- model_args, data_args, training_args = parser.parse_dict(params)
|
|
|
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
|
|
else:
|
|
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
|
|
|
@@ -226,7 +211,7 @@ def main():
|
|
|
|
|
|
out_label_list = [[] for _ in range(batch_size)]
|
|
|
preds_list = [[] for _ in range(batch_size)]
|
|
|
-
|
|
|
+
|
|
|
for i in range(batch_size):
|
|
|
for j in range(seq_len):
|
|
|
if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
|
|
@@ -237,17 +222,13 @@ def main():
|
|
|
|
|
|
def compute_metrics(p: EvalPrediction) -> Dict:
|
|
|
preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
|
|
|
-
|
|
|
+
|
|
|
return {
|
|
|
"precision": precision_score(out_label_list, preds_list),
|
|
|
"recall": recall_score(out_label_list, preds_list),
|
|
|
"f1": f1_score(out_label_list, preds_list),
|
|
|
}
|
|
|
|
|
|
- # Prevent runs directory from being created https://huggingface.co/transformers/v3.5.1/main_classes/trainer.html#transformers.TFTrainingArguments
|
|
|
- training_args.logging_dir = None
|
|
|
-
|
|
|
-
|
|
|
# Initialize our Trainer
|
|
|
trainer = Trainer(
|
|
|
model=model,
|
|
@@ -257,9 +238,6 @@ def main():
|
|
|
compute_metrics=compute_metrics,
|
|
|
)
|
|
|
|
|
|
- dags_logger = dagshub.DAGsHubLogger(should_log_hparams=False)
|
|
|
- trainer.add_callback(DAGsHubCallback(dags_logger))
|
|
|
-
|
|
|
# Training
|
|
|
if training_args.do_train:
|
|
|
trainer.train(
|
|
@@ -277,11 +255,18 @@ def main():
|
|
|
logger.info("*** Evaluate ***")
|
|
|
|
|
|
result = trainer.evaluate()
|
|
|
-
|
|
|
- output_eval_file = os.path.join(training_args.output_dir, "metrics.yaml")
|
|
|
+
|
|
|
+ output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
|
|
|
if trainer.is_world_master():
|
|
|
- results.update(result)
|
|
|
+ with open(output_eval_file, "w") as writer:
|
|
|
+ logger.info("***** Eval results *****")
|
|
|
+ for key, value in result.items():
|
|
|
+ logger.info(" %s = %s", key, value)
|
|
|
+ writer.write("%s = %s\n" % (key, value))
|
|
|
|
|
|
+ results.update(result)
|
|
|
+
|
|
|
+
|
|
|
# Predict
|
|
|
if training_args.do_predict:
|
|
|
test_dataset = NerDataset(
|
|
@@ -296,7 +281,7 @@ def main():
|
|
|
|
|
|
predictions, label_ids, metrics = trainer.predict(test_dataset)
|
|
|
preds_list, _ = align_predictions(predictions, label_ids)
|
|
|
-
|
|
|
+
|
|
|
# Save predictions
|
|
|
output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
|
|
|
if trainer.is_world_master():
|
|
@@ -306,7 +291,7 @@ def main():
|
|
|
logger.info(" %s = %s", key, value)
|
|
|
writer.write("%s = %s\n" % (key, value))
|
|
|
|
|
|
-
|
|
|
+
|
|
|
output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
|
|
|
if trainer.is_world_master():
|
|
|
with open(output_test_predictions_file, "w") as writer:
|
|
@@ -329,9 +314,7 @@ def main():
|
|
|
logger.warning(
|
|
|
"Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
|
|
|
)
|
|
|
-
|
|
|
- dags_logger.save()
|
|
|
- dags_logger.close()
|
|
|
+
|
|
|
|
|
|
return results
|
|
|
|