utils.py 594 B

1234567891011121314
  1. def load_matched_state_dict(model, state_dict, print_stats=True):
  2. """
  3. Only loads weights that matched in key and shape. Ignore other weights.
  4. """
  5. num_matched, num_total = 0, 0
  6. curr_state_dict = model.state_dict()
  7. for key in curr_state_dict.keys():
  8. num_total += 1
  9. if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
  10. curr_state_dict[key] = state_dict[key]
  11. num_matched += 1
  12. model.load_state_dict(curr_state_dict)
  13. if print_stats:
  14. print(f'Loaded state_dict: {num_matched}/{num_total} matched')