Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#284 Fix training prints

Merged
GitHub User merged 1 commits into Deci-AI:master from deci-ai:hotfix/SG-000-fix_training_prints
@@ -850,9 +850,10 @@ class SgModel:
 
 
         # Instantiate the values to monitor (loss/metric)
         # Instantiate the values to monitor (loss/metric)
         for loss in self.loss_logging_items_names:
         for loss in self.loss_logging_items_names:
-            self.train_monitored_values[loss] = MonitoredValue()
-            self.valid_monitored_values[loss] = MonitoredValue()
-        self.valid_monitored_values[self.metric_to_watch] = MonitoredValue()
+            self.train_monitored_values[loss] = MonitoredValue(name=loss, greater_is_better=False)
+            self.valid_monitored_values[loss] = MonitoredValue(name=loss, greater_is_better=False)
+        self.valid_monitored_values[self.metric_to_watch] = MonitoredValue(name=self.metric_to_watch,
+                                                                           greater_is_better=True)
 
 
         # Allowing loading instantiated loss or string
         # Allowing loading instantiated loss or string
         if isinstance(self.training_params.loss, str):
         if isinstance(self.training_params.loss, str):
Discard
@@ -28,25 +28,42 @@ class MonitoredValue:
 
 
     The value can be a metric/loss, and the iteration can be epochs/batch.
     The value can be a metric/loss, and the iteration can be epochs/batch.
     """
     """
+    name: str
+    greater_is_better: bool
     current: float = None
     current: float = None
     previous: float = None
     previous: float = None
     best: float = None
     best: float = None
     change_from_previous: float = None
     change_from_previous: float = None
     change_from_best: float = None
     change_from_best: float = None
-    is_better_than_previous: bool = None
-    is_best_value: bool = None
+
+    @property
+    def is_better_than_previous(self):
+        if self.greater_is_better is None or self.change_from_best is None:
+            return None
+        elif self.greater_is_better:
+            return self.change_from_previous >= 0
+        else:
+            return self.change_from_previous < 0
+
+    @property
+    def is_best_value(self):
+        if self.greater_is_better is None or self.change_from_best is None:
+            return None
+        elif self.greater_is_better:
+            return self.change_from_best >= 0
+        else:
+            return self.change_from_best < 0
 
 
 
 
-def update_monitored_value(previous_monitored_value: MonitoredValue, new_value: float,
-                           greater_is_better: bool) -> MonitoredValue:
+def update_monitored_value(previous_monitored_value: MonitoredValue, new_value: float) -> MonitoredValue:
     """Update the given ValueToMonitor object (could be a loss or a metric) with the new value
     """Update the given ValueToMonitor object (could be a loss or a metric) with the new value
 
 
     :param previous_monitored_value: The stats about the value that is monitored throughout epochs.
     :param previous_monitored_value: The stats about the value that is monitored throughout epochs.
     :param new_value: The value of the current epoch that will be used to update previous_monitored_value
     :param new_value: The value of the current epoch that will be used to update previous_monitored_value
-    :param greater_is_better: True when a greater value means better result.
     :return:
     :return:
     """
     """
     previous_value, previous_best_value = previous_monitored_value.current, previous_monitored_value.best
     previous_value, previous_best_value = previous_monitored_value.current, previous_monitored_value.best
+    name, greater_is_better = previous_monitored_value.name, previous_monitored_value.greater_is_better
 
 
     if previous_best_value is None:
     if previous_best_value is None:
         previous_best_value = previous_value
         previous_best_value = previous_value
@@ -58,17 +75,13 @@ def update_monitored_value(previous_monitored_value: MonitoredValue, new_value:
     if previous_value is None:
     if previous_value is None:
         change_from_previous = None
         change_from_previous = None
         change_from_best = None
         change_from_best = None
-        is_better_than_previous = None
-        is_best_value = None
     else:
     else:
         change_from_previous = new_value - previous_value
         change_from_previous = new_value - previous_value
         change_from_best = new_value - previous_best_value
         change_from_best = new_value - previous_best_value
-        is_better_than_previous = change_from_previous >= 0 if greater_is_better else change_from_previous <= 0
-        is_best_value = change_from_best >= 0 if greater_is_better else change_from_best <= 0
 
 
-    return MonitoredValue(current=new_value, previous=previous_value, best=previous_best_value,
+    return MonitoredValue(name=name, current=new_value, previous=previous_value, best=previous_best_value,
                           change_from_previous=change_from_previous, change_from_best=change_from_best,
                           change_from_previous=change_from_previous, change_from_best=change_from_best,
-                          is_better_than_previous=is_better_than_previous, is_best_value=is_best_value)
+                          greater_is_better=greater_is_better)
 
 
 
 
 def update_monitored_values_dict(monitored_values_dict: Dict[str, MonitoredValue],
 def update_monitored_values_dict(monitored_values_dict: Dict[str, MonitoredValue],
@@ -83,7 +96,6 @@ def update_monitored_values_dict(monitored_values_dict: Dict[str, MonitoredValue
         monitored_values_dict[monitored_value_name] = update_monitored_value(
         monitored_values_dict[monitored_value_name] = update_monitored_value(
             new_value=new_values_dict[monitored_value_name],
             new_value=new_values_dict[monitored_value_name],
             previous_monitored_value=monitored_values_dict[monitored_value_name],
             previous_monitored_value=monitored_values_dict[monitored_value_name],
-            greater_is_better=False
         )
         )
     return monitored_values_dict
     return monitored_values_dict
 
 
Discard