
DeepOrdinal


Ordinal output layers and loss functions for PyTorch and TensorFlow/Keras, based on Rennie & Srebro (2005).

DeepOrdinal provides an OrdinalOutput layer that converts a learned logit into ordinal class probabilities via sorted thresholds, plus loss functions designed specifically for ordinal regression.

Installation

pip install deepordinal

Install with a specific backend:

pip install "deepordinal[torch]"  # PyTorch
pip install "deepordinal[tf]"     # TensorFlow/Keras

For development:

pip install -e ".[torch,tf]"

Quick Start

PyTorch

import torch
import torch.nn as nn
from deepordinal.torch import OrdinalOutput, ordinal_loss

model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    OrdinalOutput(input_dim=16, output_dim=4),
)
ordinal_layer = model[2]
optimizer = torch.optim.Adam(model.parameters())

x_batch = torch.randn(32, 8)                  # synthetic features
y_batch = torch.randint(0, 4, (32,))          # integer labels in [0, 4)

# Training step
h = model[1](model[0](x_batch))               # hidden features
logits = ordinal_layer.linear(h)              # raw logits, shape (batch, 1)
thresholds = ordinal_layer.interior_thresholds
loss = ordinal_loss(logits, y_batch, thresholds)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# Inference
probs = model(x_batch)        # (batch, K) class probabilities
preds = probs.argmax(dim=1)

TensorFlow/Keras

import tensorflow as tf
from deepordinal.tf import OrdinalOutput, ordinal_loss

dense = tf.keras.layers.Dense(16, activation="relu")
ordinal = OrdinalOutput(output_dim=4)
optimizer = tf.keras.optimizers.Adam()

x_batch = tf.random.normal((32, 8))                           # synthetic features
y_batch = tf.random.uniform((32,), maxval=4, dtype=tf.int32)  # labels in [0, 4)
_ = ordinal(dense(x_batch))  # call once so both layers build their weights
variables = dense.trainable_variables + ordinal.trainable_variables

# Training step
with tf.GradientTape() as tape:
    h = dense(x_batch)
    logits = tf.matmul(h, ordinal.kernel) + ordinal.bias
    thresholds = ordinal.interior_thresholds[0]  # attribute has shape (1, K-1)
    loss = ordinal_loss(logits, y_batch, thresholds)
grads = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(grads, variables))

# Inference
probs = ordinal(dense(x_batch))  # (batch, K) class probabilities
preds = tf.argmax(probs, axis=1)

API

OrdinalOutput Layer

Projects an input down to a single logit and converts it into K class probabilities using K-1 learned, sorted thresholds:

P(y = k | x) = sigmoid(t_{k+1} - z) - sigmoid(t_k - z)

where z is the single logit, t_0 = -inf and t_K = +inf are fixed, and the K-1 interior thresholds t_1, ..., t_{K-1} are learned (initialized in sorted order).
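
To make the mapping concrete, here is a small standalone computation (an illustration only, not the library's code):

import torch

z = torch.tensor(0.3)                # a single logit
t = torch.tensor([-1.0, 0.0, 1.0])   # K-1 = 3 interior thresholds, so K = 4
inf = torch.tensor([float("inf")])
edges = torch.cat([-inf, t, inf])    # t_0 = -inf, interior thresholds, t_K = +inf

# P(y = k | x) = sigmoid(t_{k+1} - z) - sigmoid(t_k - z) for k = 0..K-1
probs = torch.sigmoid(edges[1:] - z) - torch.sigmoid(edges[:-1] - z)
print(probs, probs.sum())            # four non-negative probabilities; the sum telescopes to 1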

PyTorch:

  • Import: from deepordinal.torch import OrdinalOutput
  • Constructor: OrdinalOutput(input_dim, output_dim)
  • Logit access: layer.linear(h)
  • Thresholds: layer.interior_thresholds, shape (K-1,)

TensorFlow/Keras:

  • Import: from deepordinal.tf import OrdinalOutput
  • Constructor: OrdinalOutput(output_dim)
  • Logit access: tf.matmul(h, layer.kernel) + layer.bias
  • Thresholds: layer.interior_thresholds[0] (the attribute has shape (1, K-1))

ordinal_loss

Threshold-based ordinal loss from Rennie & Srebro (2005). It operates on raw logits and thresholds rather than on the layer's probability output.

ordinal_loss(logits, targets, thresholds, construction="all", penalty="logistic")

Parameters:

  • logits: (batch,) or (batch, 1), raw predictor output
  • targets: (batch,), integer labels in [0, K)
  • thresholds: (K-1,), sorted interior thresholds
  • construction: "all" (default) or "immediate"
  • penalty: "logistic" (default), "hinge", "smooth_hinge", or "modified_least_squares"

Returns: scalar mean loss over the batch.
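
For instance, a self-contained call that overrides both defaults (synthetic tensors, shapes as documented above):

import torch
from deepordinal.torch import ordinal_loss

logits = torch.randn(32, 1)                   # raw predictor output
targets = torch.randint(0, 4, (32,))          # integer labels in [0, K), K = 4
thresholds = torch.tensor([-1.0, 0.0, 1.0])   # K-1 sorted interior thresholds

loss = ordinal_loss(logits, targets, thresholds,
                    construction="immediate", penalty="smooth_hinge")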

Constructions

  • All-threshold (default, eq. 13): penalizes violations of every threshold, weighted by direction, and bounds the mean absolute error. It was the best performer in the paper's experiments; a sketch of the computation follows this list.
  • Immediate-threshold (eq. 12): penalizes violations of only the two thresholds bounding the correct class segment.
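
As an illustration, the all-threshold construction with the logistic penalty can be sketched in plain PyTorch (a reading of eq. 13, not the package's actual implementation):

import torch
import torch.nn.functional as F

def all_threshold_logistic(logits, targets, thresholds):
    # logits: (batch,) raw predictor output
    # targets: (batch,) integer labels in [0, K)
    # thresholds: (K-1,) sorted interior thresholds
    z = logits.reshape(-1, 1)                      # (batch, 1)
    t = thresholds.reshape(1, -1)                  # (1, K-1)
    idx = torch.arange(t.shape[1]).unsqueeze(0)    # threshold indices 0..K-2
    # Thresholds below the target's segment should sit under z (positive margin
    # sign); thresholds at or above it should sit over z (negative margin sign).
    sign = torch.where(idx < targets.reshape(-1, 1),
                       torch.tensor(1.0), torch.tensor(-1.0))
    margins = sign * (z - t)
    return F.softplus(-margins).sum(dim=1).mean()  # logistic penalty log(1 + exp(-m))

The immediate-threshold variant would keep only the margin terms for the two thresholds bounding the target's segment instead of summing over all K-1.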

Penalty Functions

Name                       Formula                                                      Reference
"logistic"                 log(1 + exp(-z))                                             eq. 9
"hinge"                    max(0, 1 - z)                                                eq. 5
"smooth_hinge"             0 if z >= 1; (1 - z)^2 / 2 if 0 < z < 1; 0.5 - z if z <= 0   eq. 6
"modified_least_squares"   0 if z >= 1; (1 - z)^2 if z < 1                              eq. 7

The paper recommends all-threshold + logistic as the best-performing combination (the default).
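
The formulas translate directly to code; a hypothetical PyTorch rendering of the four penalties (not the package's internals):

import torch
import torch.nn.functional as F

def logistic(z):                 # eq. 9
    return F.softplus(-z)        # numerically stable log(1 + exp(-z))

def hinge(z):                    # eq. 5
    return torch.clamp(1 - z, min=0)

def smooth_hinge(z):             # eq. 6
    return torch.where(z >= 1, torch.zeros_like(z),
                       torch.where(z > 0, (1 - z) ** 2 / 2, 0.5 - z))

def modified_least_squares(z):   # eq. 7
    return torch.where(z >= 1, torch.zeros_like(z), (1 - z) ** 2)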

ordistic_loss

A probabilistic generalization of logistic regression to K-class ordinal problems (Section 4 of the paper).

ordistic_loss(logits, targets, means, log_priors=None)

Parameters:

  • logits: (batch,) or (batch, 1), raw predictor output
  • targets: (batch,), integer labels in [0, K)
  • means: (K,), class means (convention: mu_1 = -1, mu_K = 1; interior means learned)
  • log_priors: (K,) or None, optional log-prior terms

Returns: scalar mean negative log-likelihood over the batch.
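
A minimal call, assuming ordistic_loss is importable from deepordinal.torch alongside ordinal_loss, with evenly spaced means fixed for illustration (in practice the interior means would be learned):

import torch
from deepordinal.torch import ordistic_loss  # assumed import path

K = 4
means = torch.linspace(-1.0, 1.0, K)   # mu_1 = -1, mu_K = 1 per the convention above
logits = torch.randn(32)
targets = torch.randint(0, K, (32,))

loss = ordistic_loss(logits, targets, means)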

Examples

A complete training loop on synthetic ordinal data can be condensed as follows.
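
This sketch assembles the Quick Start pieces into one loop; the synthetic labels are built by bucketing a noisy linear score into K ordered levels, so an ordinal model fits the task:

import torch
import torch.nn as nn
from deepordinal.torch import OrdinalOutput, ordinal_loss

K = 4
X = torch.randn(256, 8)
# Ordered synthetic targets: bucket a noisy linear score into K levels.
score = X @ torch.randn(8) + 0.1 * torch.randn(256)
y = torch.bucketize(score, torch.quantile(score, torch.tensor([0.25, 0.50, 0.75])))

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(),
                      OrdinalOutput(input_dim=16, output_dim=K))
ordinal_layer = model[2]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for epoch in range(50):
    h = model[1](model[0](X))
    logits = ordinal_layer.linear(h)
    loss = ordinal_loss(logits, y, ordinal_layer.interior_thresholds)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

preds = model(X).argmax(dim=1)
print(f"final loss {loss.item():.3f}, accuracy {(preds == y).float().mean():.2%}")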

Testing

pip install -e ".[torch,tf]"
pytest -v

License

MIT

Citation

Rennie, J. D. M. & Srebro, N. (2005). Loss Functions for Preference Levels: Regression with Discrete Ordered Labels. Proceedings of the IJCAI Multidisciplinary Workshop on Advances in Preference Handling.
