Initial work for new 4.0 API #90

Draft · wants to merge 3 commits into base: dev

8 changes: 4 additions & 4 deletions deeplc/__init__.py
@@ -1,8 +1,8 @@
__all__ = ["DeepLC"]
# __all__ = ["DeepLC"]

from importlib.metadata import version
# from importlib.metadata import version

__version__ = version("deeplc")
# __version__ = version("deeplc")


from deeplc.deeplc import DeepLC
# from deeplc.deeplc import DeepLC
53 changes: 53 additions & 0 deletions deeplc/_data.py
@@ -0,0 +1,53 @@
import numpy as np
import torch
from psm_utils import Peptidoform, PSMList
from torch.utils.data import Dataset

from deeplc._features import encode_peptidoform


class DeepLCDataset(Dataset):
    """PyTorch Dataset that encodes DeepLC features from peptidoform sequences on the fly."""

def __init__(
self,
peptidoforms: list[Peptidoform | str],
target_retention_times: np.ndarray | None = None,
add_ccs_features: bool = False
):
self.peptidoforms = peptidoforms
self.target_retention_times = target_retention_times
self.add_ccs_features = add_ccs_features

def __len__(self):
return len(self.peptidoforms)

def __getitem__(self, idx) -> tuple:
if not isinstance(idx, int):
raise TypeError(f"Index must be an integer, got {type(idx)} instead.")
features = encode_peptidoform(
self.peptidoforms[idx],
add_ccs_features=self.add_ccs_features
)
feature_tuples = (
torch.from_numpy(features["matrix"]).to(dtype=torch.float32),
torch.from_numpy(features["matrix_sum"]).to(dtype=torch.float32),
torch.from_numpy(features["matrix_global"]).to(dtype=torch.float32),
torch.from_numpy(features["matrix_hc"]).to(dtype=torch.float32),
)
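        # Fall back to a NaN placeholder (shaped like the first feature matrix) when no targets are given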
targets = (
self.target_retention_times[idx]
if self.target_retention_times is not None
else torch.full_like(
feature_tuples[0], fill_value=float('nan'), dtype=torch.float32
)
)
return feature_tuples, targets


def get_targets(psm_list: PSMList) -> torch.Tensor | None:
    """Return observed retention times as a float32 tensor, or None if any are missing."""
    retention_times = psm_list["retention_time"]
    if None not in retention_times:
        return torch.tensor(retention_times, dtype=torch.float32)
    else:
        return None
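
For reviewers, a minimal usage sketch of the new DeepLCDataset with a standard PyTorch DataLoader (not part of this PR). It assumes this branch is installed and that encode_peptidoform accepts plain ProForma strings, as its signature suggests; the peptidoforms and retention times below are made up.

import numpy as np
from torch.utils.data import DataLoader

from deeplc._data import DeepLCDataset

# Made-up calibration peptidoforms (ProForma strings with charge) and retention times
peptidoforms = ["ACDEFGHIK/2", "LMNPQRSTVWYK/3"]
retention_times = np.array([12.3, 45.6], dtype=np.float32)

dataset = DeepLCDataset(peptidoforms, target_retention_times=retention_times)
loader = DataLoader(dataset, batch_size=2, shuffle=False)

for features, targets in loader:
    # default_collate keeps the 4-tuple structure: one stacked tensor per feature matrix
    x, x_sum, x_global, x_hc = features
    print(x.shape, x_sum.shape, x_global.shape, x_hc.shape, targets.shape)
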
36 changes: 6 additions & 30 deletions deeplc/feat_extractor.py → deeplc/_features.py
@@ -7,7 +7,7 @@
from re import sub

import numpy as np
from psm_utils.psm import Peptidoform
from psm_utils import Peptidoform, PSMList
from pyteomics import mass

logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ def _compute_rolling_sum(matrix: np.ndarray, n: int = 2) -> np.ndarray:

def encode_peptidoform(
peptidoform: Peptidoform | str,
predict_ccs: bool = False,
add_ccs_features: bool = False,
padding_length: int = 60,
positions: set[int] | None = None,
positions_pos: set[int] | None = None,
@@ -188,7 +188,7 @@ def encode_peptidoform(

matrix_all = np.sum(std_matrix, axis=0)
matrix_all = np.append(matrix_all, seq_len)
if predict_ccs:
if add_ccs_features:
matrix_all = np.append(matrix_all, (seq.count("H")) / seq_len)
matrix_all = np.append(
matrix_all, (seq.count("F") + seq.count("W") + seq.count("Y")) / seq_len
@@ -198,36 +198,12 @@
matrix_all = np.append(matrix_all, charge)

matrix_sum = _compute_rolling_sum(std_matrix.T, n=2)[:, ::2].T

matrix_global = np.concatenate([matrix_all, pos_matrix.flatten()])

return {
"matrix": std_matrix,
"matrix_sum": matrix_sum,
"matrix_all": matrix_all,
"pos_matrix": pos_matrix.flatten(),
"matrix_global": matrix_global,
"matrix_hc": onehot_matrix,
}


def aggregate_encodings(
encodings: list[dict[str, np.ndarray]],
) -> dict[str, dict[int, np.ndarray]]:
"""Aggregate list of encodings into single dictionary."""
return {key: {i: enc[key] for i, enc in enumerate(encodings)} for key in encodings[0]}


def unpack_features(
features: dict[str, np.ndarray],
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Unpack dictionary with features to numpy arrays."""
X_sum = np.stack(list(features["matrix_sum"].values()))
X_global = np.concatenate(
(
np.stack(list(features["matrix_all"].values())),
np.stack(list(features["pos_matrix"].values())),
),
axis=1,
)
X_hc = np.stack(list(features["matrix_hc"].values()))
X_main = np.stack(list(features["matrix"].values()))

return X_sum, X_global, X_hc, X_main
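
With aggregate_encodings and unpack_features removed, batching is delegated to DeepLCDataset, and each encoding carries a single "matrix_global" entry instead of separate "matrix_all" and "pos_matrix" arrays. A rough illustration of the equivalent batch-level stacking (not part of this PR; the peptidoform strings are made up, and it assumes encode_peptidoform accepts plain ProForma strings):

import numpy as np

from deeplc._features import encode_peptidoform

# Made-up peptidoforms; any ProForma strings accepted by psm_utils should work here
encodings = [encode_peptidoform(p) for p in ["ACDEFGHIK/2", "LMNPQRSTVWYK/3"]]

# "matrix_global" already concatenates the former "matrix_all" and "pos_matrix" per peptide,
# so the old unpack_features() concatenation reduces to a plain np.stack per key
X_main = np.stack([e["matrix"] for e in encodings])
X_sum = np.stack([e["matrix_sum"] for e in encodings])
X_global = np.stack([e["matrix_global"] for e in encodings])
X_hc = np.stack([e["matrix_hc"] for e in encodings])
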
139 changes: 139 additions & 0 deletions deeplc/_finetune.py
@@ -0,0 +1,139 @@
from __future__ import annotations

import copy
import logging

import torch
from torch.utils.data import DataLoader

LOGGER = logging.getLogger(__name__)


class DeepLCFineTuner:
"""
Class for fine-tuning a DeepLC model.

Parameters
----------
model : torch.nn.Module
The model to fine-tune.
train_data : torch.utils.data.Dataset
Dataset containing the training data.
device : str, optional, default='cpu'
The device on which to run the model ('cpu' or 'cuda').
learning_rate : float, optional, default=0.001
The learning rate for the optimizer.
epochs : int, optional, default=10
Number of training epochs.
batch_size : int, optional, default=256
Batch size for training.
validation_data : torch.utils.data.Dataset or None, optional
If provided, used directly for validation. Otherwise, a fraction of
`train_data` will be held out.
validation_split : float, optional, default=0.1
Fraction of `train_data` to reserve for validation when
`validation_data` is None.
patience : int, optional, default=5
Number of epochs with no improvement on validation loss before stopping.
"""

def __init__(
self,
model,
train_data,
device="cpu",
learning_rate=0.001,
epochs=10,
batch_size=256,
validation_data=None,
validation_split=0.1,
patience=5,
):
self.model = model.to(device)
self.train_data = train_data
self.device = device
self.learning_rate = learning_rate
self.epochs = epochs
self.batch_size = batch_size
self.validation_data = validation_data
self.validation_split = validation_split
self.patience = patience

def _freeze_layers(self, unfreeze_keywords="33_1"):
"""
        Freezes all layers except those whose names contain ``unfreeze_keywords``.
"""

for name, param in self.model.named_parameters():
param.requires_grad = unfreeze_keywords in name

def prepare_data(self, data, shuffle=True):
return DataLoader(data, batch_size=self.batch_size, shuffle=shuffle)

def fine_tune(self):
LOGGER.debug("Starting fine-tuning...")
if self.validation_data is None:
# Split the training data into training and validation sets
val_size = int(len(self.train_data) * self.validation_split)
train_size = len(self.train_data) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(
self.train_data, [train_size, val_size]
)
else:
train_dataset = self.train_data
val_dataset = self.validation_data
train_loader = self.prepare_data(train_dataset)
val_loader = self.prepare_data(val_dataset, shuffle=False)

optimizer = torch.optim.Adam(
filter(lambda p: p.requires_grad, self.model.parameters()),
lr=self.learning_rate,
)
loss_fn = torch.nn.L1Loss()
best_model_wts = copy.deepcopy(self.model.state_dict())
best_val_loss = float("inf")
epochs_no_improve = 0

for epoch in range(self.epochs):
running_loss = 0.0
self.model.train()
for batch in train_loader:
                # DeepLCDataset yields ((X, X_sum, X_global, X_hc), target)
                (batch_X, batch_X_sum, batch_X_global, batch_X_hc), target = batch
                batch_X, batch_X_sum, batch_X_global, batch_X_hc = (
                    t.to(self.device) for t in (batch_X, batch_X_sum, batch_X_global, batch_X_hc)
                )
                target = target.view(-1, 1).to(self.device)

                optimizer.zero_grad()
                outputs = self.model(batch_X, batch_X_sum, batch_X_global, batch_X_hc)
loss = loss_fn(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / len(train_loader)

self.model.eval()
val_loss = 0.0
with torch.no_grad():
for batch in val_loader:
                    (batch_X, batch_X_sum, batch_X_global, batch_X_hc), target = batch
                    batch_X, batch_X_sum, batch_X_global, batch_X_hc = (
                        t.to(self.device) for t in (batch_X, batch_X_sum, batch_X_global, batch_X_hc)
                    )
                    target = target.view(-1, 1).to(self.device)
                    outputs = self.model(batch_X, batch_X_sum, batch_X_global, batch_X_hc)
val_loss += loss_fn(outputs, target).item()
avg_val_loss = val_loss / len(val_loader)

LOGGER.debug(
f"Epoch {epoch + 1}/{self.epochs}, "
f"Loss: {avg_loss:.4f}, "
f"Validation Loss: {avg_val_loss:.4f}"
)
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
best_model_wts = copy.deepcopy(self.model.state_dict())
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= self.patience:
                    LOGGER.debug(f"Early stopping triggered at epoch {epoch + 1}")
break
self.model.load_state_dict(best_model_wts)
return self.model
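
A sketch of the intended fine-tuning flow (not part of this PR). The DeepLC 4.0 model itself is not included in this branch, so _ToyModel below is a made-up stand-in whose forward() takes the four feature tensors produced by DeepLCDataset; the peptidoforms and retention times are likewise made up.

import numpy as np
import torch

from deeplc._data import DeepLCDataset
from deeplc._finetune import DeepLCFineTuner


class _ToyModel(torch.nn.Module):
    """Hypothetical stand-in for a pre-trained DeepLC model: one scalar prediction per peptide."""

    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x, x_sum, x_global, x_hc):
        # Collapse the global feature vector to a single retention-time-like value
        return self.scale * x_global.mean(dim=1, keepdim=True) + self.bias


# Made-up calibration data; enough samples so the default 10% validation split is non-empty
peptidoforms = ["ACDEFGHIK/2", "LMNPQRSTVWYK/3"] * 10
retention_times = np.tile(np.array([12.3, 45.6], dtype=np.float32), 10)

train_data = DeepLCDataset(peptidoforms, target_retention_times=retention_times)
tuner = DeepLCFineTuner(model=_ToyModel(), train_data=train_data, epochs=5, batch_size=8)
fine_tuned_model = tuner.fine_tune()
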