1234567891011121314 |
def load_matched_state_dict(model, state_dict, print_stats=True):
    """Load only the entries of ``state_dict`` whose key and shape match ``model``.

    Entries of ``state_dict`` that the model does not have, or whose shape
    differs from the model's current entry, are ignored. The model keeps its
    current values for every key that was not matched.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``
            (e.g. a ``torch.nn.Module``).
        state_dict: Mapping from state key to value, e.g. a loaded checkpoint.
        print_stats: If true, print how many of the model's entries were loaded.

    Returns:
        ``(num_matched, num_total)``: the number of model entries taken from
        ``state_dict`` and the total number of entries in the model's state dict.
    """
    curr_state_dict = model.state_dict()
    matched = {}
    for key, current in curr_state_dict.items():
        if key not in state_dict:
            continue
        candidate = state_dict[key]
        # Entries without a ``shape`` (e.g. non-tensor extra state) cannot be
        # shape-checked; treat them as non-matching instead of raising
        # AttributeError.
        current_shape = getattr(current, 'shape', None)
        if current_shape is not None and current_shape == getattr(candidate, 'shape', None):
            matched[key] = candidate
    # Merge after the scan so the dict is not mutated while being iterated.
    curr_state_dict.update(matched)
    # curr_state_dict still holds every key the model expects, so the default
    # strict loading succeeds.
    model.load_state_dict(curr_state_dict)
    num_matched, num_total = len(matched), len(curr_state_dict)
    if print_stats:
        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
    return num_matched, num_total
|