1 changed file with 17 additions and 34 deletions
named-entity-recognition/run_ner.py

@@ -20,8 +20,6 @@ import logging
 import os
 import sys
 import pdb
-import yaml
-import dagshub
 import subprocess
 
 from dataclasses import dataclass, field
@@ -39,7 +37,6 @@ from transformers import (
     EvalPrediction,
     HfArgumentParser,
     Trainer,
-    TrainerCallback,
     TrainingArguments,
     set_seed,
 )
@@ -96,28 +93,16 @@ class DataTrainingArguments:
     )
 
 
-class DAGsHubCallback(TrainerCallback):
-    def __init__(self, logger):
-        super(TrainerCallback, self).__init__()
-        self.logger = logger
-
-    def on_log(self, args, state, control, logs, model=None, **kwargs):
-        if state.is_world_process_zero:
-            self.logger.log_metrics({k:v for k,v in logs.items() if isinstance(v, (int, float))}, step_num=state.global_step)
-
-
 def main():
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
 
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
-    if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        with open(os.path.abspath(sys.argv[1]), "r") as params_file:
-            params = yaml.load(params_file)
-        model_args, data_args, training_args = parser.parse_dict(params)
+        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
@@ -226,7 +211,7 @@ def main():
 
         out_label_list = [[] for _ in range(batch_size)]
         preds_list = [[] for _ in range(batch_size)]
-
+        
         for i in range(batch_size):
             for j in range(seq_len):
                 if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
@@ -237,17 +222,13 @@ def main():
 
     def compute_metrics(p: EvalPrediction) -> Dict:
         preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
-
+        
         return {
             "precision": precision_score(out_label_list, preds_list),
             "recall": recall_score(out_label_list, preds_list),
             "f1": f1_score(out_label_list, preds_list),
         }
 
-    # Prevent runs directory from being created https://huggingface.co/transformers/v3.5.1/main_classes/trainer.html#transformers.TFTrainingArguments
-    training_args.logging_dir = None
-
-
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
@@ -257,9 +238,6 @@ def main():
         compute_metrics=compute_metrics,
     )
 
-    dags_logger = dagshub.DAGsHubLogger(should_log_hparams=False)
-    trainer.add_callback(DAGsHubCallback(dags_logger))
-
     # Training
     if training_args.do_train:
         trainer.train(
@@ -277,11 +255,18 @@ def main():
         logger.info("*** Evaluate ***")
 
         result = trainer.evaluate()
-
-        output_eval_file = os.path.join(training_args.output_dir, "metrics.yaml")
+        
+        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
         if trainer.is_world_master():
-            results.update(result)
+            with open(output_eval_file, "w") as writer:
+                logger.info("***** Eval results *****")
+                for key, value in result.items():
+                    logger.info("  %s = %s", key, value)
+                    writer.write("%s = %s\n" % (key, value))
 
+            results.update(result)
+    
+    
     # Predict
     if training_args.do_predict:
         test_dataset = NerDataset(
@@ -296,7 +281,7 @@ def main():
 
         predictions, label_ids, metrics = trainer.predict(test_dataset)
         preds_list, _ = align_predictions(predictions, label_ids)
-
+        
         # Save predictions
         output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
         if trainer.is_world_master():
@@ -306,7 +291,7 @@ def main():
                     logger.info("  %s = %s", key, value)
                     writer.write("%s = %s\n" % (key, value))
 
-
+        
         output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
         if trainer.is_world_master():
             with open(output_test_predictions_file, "w") as writer:
@@ -329,9 +314,7 @@ def main():
                             logger.warning(
                                 "Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0]
                             )
-
-    dags_logger.save()
-    dags_logger.close()
+            
 
     return results
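
For reference, the config-loading path this diff switches back to is `HfArgumentParser.parse_json_file`. A minimal sketch of how that call behaves, assuming the `transformers` library is installed (`params.json` is a hypothetical file name, and the fields shown are just illustrative `TrainingArguments` fields):

```python
# Minimal sketch of the JSON parsing path run_ner.py now uses.
# Assumes transformers is installed; "params.json" is a hypothetical file.
import json

from transformers import HfArgumentParser, TrainingArguments

# Write a small config with fields TrainingArguments understands.
with open("params.json", "w") as f:
    json.dump({"output_dir": "./out", "do_train": True, "num_train_epochs": 1}, f)

# parse_json_file returns one dataclass instance per type given to the parser,
# mirroring the (model_args, data_args, training_args) unpacking in main().
(training_args,) = HfArgumentParser(TrainingArguments).parse_json_file(
    json_file="params.json"
)
print(training_args.output_dir)  # ./out
```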
 
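The `compute_metrics` helper left in place by this diff scores entities rather than individual tokens. A small sketch of the seqeval functions it relies on (`run_ner.py` imports them from `seqeval.metrics`; the label sequences below are invented examples):

```python
# Entity-level scoring as used by compute_metrics in run_ner.py.
# Assumes the seqeval package; the label sequences are made-up examples.
from seqeval.metrics import f1_score, precision_score, recall_score

gold = [["B-PER", "I-PER", "O", "B-LOC"]]  # one PER span, one LOC span
pred = [["B-PER", "I-PER", "O", "O"]]      # PER matched exactly, LOC missed

print(precision_score(gold, pred))  # 1.0   (the one predicted span is correct)
print(recall_score(gold, pred))     # 0.5   (one of two gold spans was found)
print(f1_score(gold, pred))         # ~0.667
```

A span only counts as correct when both its boundaries and its type match the gold annotation, which is why these entity-level numbers are typically lower than token-level accuracy.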