
Project description

data_improvement_library

A group of techniques and methods that can potentially improve visual explainability models based on the Learning to Rank approach.

Usage Guide

Each module in the library requires a set of input variables that define its functionality. This section describes these variables and provides usage examples for each module.

Data Augmentation Module

Input Variables

  • data_dir (str): Path to the input data file containing user interactions.
  • vector_dir (str): Path to the file containing image embeddings.
  • image_dir (str): Path to the directory where images are stored.
  • output_dir (str): Path where the processed data will be stored.
  • output_name (str, optional): Name of the output file. Default: TRAIN_IMG.
  • embedding_model (torch.nn.Module, optional): Pretrained model for embedding extraction. If None, a pretrained ViT Large 14 model is used by default.
  • no_aug (bool, optional): If True, disables data augmentation.
  • batch_size (int, optional): Number of images processed per batch. Default: 32.
  • apply_all (bool, optional): If True, applies all available transformations.
  • labels (list, optional): List of column names in the dataset.

Output Variables

This module does not produce a return value. Instead, it saves the processed data in pickle format based on the input variables.

Usage Example

import data_improvement_library

# Process images with data augmentation and embedding extraction
data_improvement_library.augment_data(data_dir="data/user_data.pkl", vector_dir="data/image_vectors.pkl",
             image_dir="images/", output_dir="processed_data/",
             output_name="TRAIN_IMG", embedding_model=None,
             no_aug=False, batch_size=32, apply_all=True, labels=None)
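
Since the module saves its output as a pickle rather than returning it, downstream code reloads the file with pandas. A minimal round-trip sketch; the file name and columns below are placeholders, not the module's documented schema:

```python
import os
import tempfile
import pandas as pd

# Placeholder frame standing in for the module's pickled output;
# the columns are illustrative, not the documented schema.
df = pd.DataFrame({"id_user": [1, 2], "id_img": [10, 20], "take": [1, 0]})

out_path = os.path.join(tempfile.mkdtemp(), "TRAIN_IMG.pkl")
df.to_pickle(out_path)                # how the output is persisted
reloaded = pd.read_pickle(out_path)   # how downstream code reloads it
assert reloaded.equals(df)
```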

Embedding Generation Module (new_embeddings.py)

Input Variables

  • directory (str): Directory where images are stored.
  • output_dir (str): Directory where generated embeddings will be saved.
  • output_name (str): Name of the output file.
  • embedding_model (torch.nn.Module or None): Pretrained model for embedding extraction. If None, a pretrained ViT Large 14 model is used by default.
  • batch_size (int): Number of images processed per batch.

Output Variables

This module does not return a value. Instead, it saves the processed data in pickle format.

Usage Example

import data_improvement_library

# Generate embeddings for a set of images
data_improvement_library.create_new_embeddings(directory="images/", output_dir="embeddings/", output_name="img_vectors.pkl",
                      embedding_model=None, batch_size=32)

Negative Selection with PU Learning (pu_negatives.py)

Input Variables

  • data_file (str): Path to the user interaction data file.
  • vector_file (str): Path to the image embeddings file.
  • outdir_name (str): Output directory where the processed dataset will be saved.
  • centroid (int, optional): Centroid percentile (default: 90).
  • factor (float, optional): Distance adjustment factor.
  • labels (list, optional): Column labels in the dataset.

Output Variables

This module does not return a value. Instead, it saves the processed data in pickle format.

Usage Example

import data_improvement_library

# Generate balanced negative samples
data_improvement_library.resample_negatives(data_file="data/user_data.pkl", vector_file="data/image_vectors.pkl",
                   outdir_name="processed_data/", centroid=90, factor=1.0)
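
To build intuition for the centroid and factor parameters, here is an illustrative sketch of distance-based negative selection. This is an assumption about the general idea, not the library's actual implementation; all data below is synthetic:

```python
import numpy as np

rng = np.random.default_rng(0)
item_vecs = rng.normal(size=(100, 16))   # stand-in image embeddings
user_items = item_vecs[:10]              # items this user interacted with
centroid = user_items.mean(axis=0)       # user centroid in embedding space

# Threshold: the 90th-percentile distance among the user's own items,
# scaled by a distance adjustment factor.
user_dists = np.linalg.norm(user_items - centroid, axis=1)
threshold = np.percentile(user_dists, 90)
factor = 1.0

# Items farther from the centroid than the scaled threshold are
# treated as likely negatives.
all_dists = np.linalg.norm(item_vecs - centroid, axis=1)
negatives = np.flatnonzero(all_dists > factor * threshold)
```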

Positive Selection with PU Learning (pu_positives.py)

This module can be used in two main ways:

  1. As a Callback in model training with PyTorch Lightning, allowing dynamic updates of user centroids in each epoch.
  2. As Preprocessing, applied before training to the dataset.

Usage as Callback

The module provides the CallbackEndPositives class, which extends Callback from PyTorch Lightning. Its purpose is to update user centroids and resample positive samples at the end of each training epoch.

For proper functionality, the model must meet the following requirements:

  • Dataset in train_dataloader must contain:
    • datamodule.image_embeddings: Tensor of image embeddings.
    • dataframe: Must include columns id_user, id_img, id_restaurant, take.
  • Embedding access: Must be indexable by id_img in the dataframe.
  • Embedding format: Must be a tensor convertible to NumPy (.cpu().detach().numpy()).
  • Data sampling: Must allow modification of pu_dataset in the dataset.
  • DataModule requirements: Must provide image_embeddings and handle pu_dataset correctly.
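
As a small illustration of the indexing requirement above, the snippet below shows embeddings indexable by id_img alongside a dataframe with the required columns. NumPy stands in for the torch tensor the callback expects, and all values are made up:

```python
import numpy as np
import pandas as pd

# Embeddings indexable by id_img: row i holds the embedding of image i.
image_embeddings = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 images, dim 3

# Interaction dataframe with the required columns.
df = pd.DataFrame({"id_user": [0, 0, 1],
                   "id_img": [2, 0, 3],
                   "id_restaurant": [5, 5, 6],
                   "take": [1, 1, 1]})

# One embedding row per interaction; with a real torch tensor, this is
# the point where .cpu().detach().numpy() would convert to NumPy.
vecs = image_embeddings[df["id_img"].to_numpy()]
```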

Usage Example

import data_improvement_library
from pytorch_lightning import Trainer

# Create an instance of the callback
callback_end = data_improvement_library.CallbackEndPositives()

# Configure the trainer with the callback
trainer = Trainer(callbacks=[callback_end])

# Start training
trainer.fit(model)

Usage as Preprocessing

This module can also generate a dataset with new positive samples, selected by user similarity, before training. It does so by calculating user centroids and resampling positive samples. The functionality is split across two functions: centroid_users, which computes user centroids, and resample_positives, which updates the dataset with new positive samples using cosine-similarity filtering.
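
The cosine-similarity filtering idea can be sketched as follows. This is toy data with a hand-rolled similarity function, not the library's code:

```python
import numpy as np

def cosine(u, v):
    # Cosine similarity between two embedding vectors.
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

# Toy centroids: user ID -> mean embedding of that user's positive images.
centroids = {0: np.array([1.0, 0.0]),
             1: np.array([0.9, 0.1]),
             2: np.array([0.0, 1.0])}

# For a target user, keep the k users with the most similar centroids;
# their positives become candidate new positives for the target user.
target = 0
sims = {u: cosine(centroids[target], c)
        for u, c in centroids.items() if u != target}
k = 1
similar_users = sorted(sims, key=sims.get, reverse=True)[:k]
```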

Input Variables

For the function centroid_users:

  • dataframe (pandas.DataFrame): Dataset of user-image pairs.
  • vectors (numpy.ndarray): Image embedding representations.
  • labels (list, optional): List of column names in the dataset.

For the function resample_positives:

  • dataframe (pandas.DataFrame): Dataset of user-image pairs.
  • centroids (dict): Dictionary where keys represent user IDs, and values are user image centroids as embeddings.
  • k (int): Number of similar users considered for adding positive samples per user.

Output Variables

  • centroid_users returns centroids, a dictionary where keys are user IDs and values are user image centroids as embeddings.
  • resample_positives returns the updated dataframe with new positive samples.

Usage Example

import data_improvement_library
import pandas as pd
import numpy as np

# Load the dataset
dataframe = pd.read_pickle("data/user_data.pkl")

# Load image embeddings
image_vectors = np.load("data/image_vectors.npy")

# Compute user centroids
centroids = data_improvement_library.centroid_users(dataframe, image_vectors)

# Generate a new dataset with balanced positive samples
new_dataframe = data_improvement_library.resample_positives(dataframe, centroids, 3)

# Save the processed dataset
new_dataframe.to_pickle("processed_data/balanced_dataset.pkl")

Download files


Source Distribution

data_improvement_library-0.3.2.tar.gz (14.3 kB view details)

Uploaded Source

Built Distribution


data_improvement_library-0.3.2-py3-none-any.whl (15.8 kB view details)

Uploaded Python 3

File details

Details for the file data_improvement_library-0.3.2.tar.gz.

File metadata

  • Download URL: data_improvement_library-0.3.2.tar.gz
  • Size: 14.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.1.0 CPython/3.10.10

File hashes

Hashes for data_improvement_library-0.3.2.tar.gz
Algorithm Hash digest
SHA256 699fe8ad23efc7e673ca145ce622edef39d5743be8b344b2cb26ad817a9deb75
MD5 1757c73b822ac5b541ebd8567ecc8387
BLAKE2b-256 0b535865c44dea1bdee2d4e524ae52f5f7a51a68fdc5b5a70ee47e65fc9a07b2


File details

Details for the file data_improvement_library-0.3.2-py3-none-any.whl.

File hashes

Hashes for data_improvement_library-0.3.2-py3-none-any.whl
Algorithm Hash digest
SHA256 45bb79453cdd0291d6e6061f1db5ad63bb877922575d7a4799a7f0bbd4fde0b2
MD5 90dbe15d579bb6c07e5c85be3970496f
BLAKE2b-256 c14857828dde6823191c752d0860c107de292497bf092eb51350dd9e22326c52

