# domain_mapping/analysis/class_imbalance.py
#
# Exploratory script (cells delimited by "# %%") that inspects how unevenly
# rows are distributed across entity_id values in the training data.

# %%
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
# %%
# import training file
# The path is relative to the current working directory, not to this file.
# skipinitialspace=True drops spaces that follow each delimiter.
data_path = '../data_import/train.csv'
train_df = pd.read_csv(data_path, skipinitialspace=True)
# %%
# Rows per entity_id. value_counts sorts by count (most frequent first) and
# excludes missing entity_id values by default.
id_counts = train_df['entity_id'].value_counts()
# %%
# The 50 most frequent entity_ids. .head() is always positional, whereas
# `id_counts[:50]` depends on the index dtype for its slicing semantics.
id_counts.head(50)
# %%
# Distribution of class sizes: how many entity_ids have a given row count.
plt.hist(id_counts, bins=50)
plt.xlabel('rows per entity_id')
plt.ylabel('number of entity_ids')
# %%
def compute_normalized_class_weights(class_counts, max_resamples=10, *, min_resamples=0):
    """
    Convert per-class sample counts into integer resample counts.

    Each class gets a weight inversely proportional to its count
    (``total_samples / count``). The weights are divided by their maximum,
    so the rarest class(es) get exactly 1.0. They are then multiplied by
    ``max_resamples`` and rounded to the nearest integer. The result is
    NOT normalized to sum to 1.

    Args:
        class_counts (array-like): Number of samples per class, e.g. the
            output of ``Series.value_counts()``. Only the values are used;
            any index/labels are dropped, so the result is aligned with
            ``class_counts`` by position.
        max_resamples (int): Resample count assigned to the rarest class(es).
        min_resamples (int): Keyword-only lower bound applied after rounding.
            The default of 0 keeps the historical behaviour, under which
            classes much more frequent than the rarest one round down to 0.

    Returns:
        numpy.ndarray: Integer resample counts, one per class, in the same
        order as ``class_counts``.

    Raises:
        ValueError: If ``class_counts`` is empty or contains a count <= 0.
    """
    class_counts = np.asarray(class_counts)
    if class_counts.size == 0:
        raise ValueError("class_counts must not be empty")
    if np.any(class_counts <= 0):
        # A zero count would divide by zero (inf/nan weights), and the later
        # int cast would turn those into meaningless values.
        raise ValueError("class_counts must all be positive")
    total_samples = np.sum(class_counts)
    class_weights = total_samples / class_counts
    # Divide by the largest weight so the rarest class maps to exactly 1.0.
    normalized_weights = class_weights / np.max(class_weights)
    # Scale so the rarest class corresponds to `max_resamples`.
    resample_counts = normalized_weights * max_resamples
    # np.round rounds exact halves to the nearest even integer (2.5 -> 2).
    resample_counts = np.round(resample_counts).astype(int)
    # Counts are non-negative here, so a floor of 0 leaves them unchanged.
    resample_counts = np.maximum(resample_counts, min_resamples)
    return resample_counts
# %%
# Integer resample count per entity_id, aligned with id_counts by position.
id_weights = compute_normalized_class_weights(id_counts, max_resamples=10)
# %%
# The bare array carries no entity ids. Label it with id_counts' index so
# each resample count can be read next to the entity_id it belongs to.
id_weights_by_entity = pd.Series(id_weights, index=id_counts.index, name='resample_count')
id_weights_by_entity
# %%
# Inspect the rows of a single entity_id.
# NOTE(review): the reason 536 was chosen is not recorded here.
ENTITY_ID_TO_INSPECT = 536
id_mask = train_df['entity_id'] == ENTITY_ID_TO_INSPECT
train_df[id_mask]
# %%
# entity_ids in value_counts order, i.e. the same order as id_weights.
id_counts.index.to_list()
# %%