PRSFNN


PRSFNN (Polygenic Risk Score with Functional Neural Networks) is a Julia module for calculating polygenic risk scores by integrating GWAS summary statistics with functional annotations using neural networks.

Features

  • Integration of GWAS summary statistics with functional annotations
  • Linkage disequilibrium (LD) calculation and correction
  • Coordinate Ascent Variational Inference (CAVI) for posterior effect size estimation
  • Neural network models to learn the relationship between functional annotations and genetic effect sizes

PRSFNN runner

Note: this Julia module is used for running PRSFNN in individual LD blocks and for training the neural network. To run PRSFNN genome-wide in an HPC environment, see our workflow here.

Installation

# From the Julia REPL
using Pkg
Pkg.add(url="https://github.com/weinstockj/PRS.jl")

Getting Started

Steps to load this module from the root directory:

  1. Run julia --color=yes --project=. (requires Julia 1.9.0 or later)
  2. Run using Revise # helpful while developing
  3. Run using PRSFNN

The exported functions are now available. To call an internal (unexported) function, use PRSFNN.function_name.

To enable debug-level logging for the module (messages emitted with @debug): JULIA_DEBUG=PRSFNN julia --color=yes --project=.

Usage

Basic Example

using PRSFNN

# Run PRSFNN on a genomic region
result = main(
    output_prefix = "chr3_block1",
    annot_data_path = "path/to/annotations.parquet", 
    gwas_data_path = "path/to/gwas_stats.tsv",
    ld_panel_path = "path/to/ld_panel"
)

Simulating GWAS Data

For testing and development, you can simulate GWAS summary statistic data:

# Generate simulated data
raw = simulate_raw(;N = 10_000, P = 1_000, K = 100, h2 = 0.10)
# Extract sufficient statistics needed for PRS
summary_stats = estimate_sufficient_statistics(raw.X, raw.Y)

Training a PRS Model

using Flux, StatsFuns  # provides Dense, Chain, AdamW, softplus and logit

# Train the PRS model
N = size(raw.X, 1)  # Number of samples
P = size(raw.X, 2)  # Number of SNPs
XtX = construct_XtX(summary_stats.R, ones(P), N)
D = construct_D(XtX)
Xty = construct_Xty(summary_stats.coef, D)

σ2, R2, yty = infer_σ2(
    summary_stats.coef, 
    summary_stats.SE, 
    XtX, 
    Xty, 
    N, 
    P; 
    estimate = true, 
    λ = 0.50 * N
)

K = size(raw.G, 2)
H = 5  # Number of hidden units

layer_1 = Dense(K => H, Flux.softplus; init = Flux.glorot_normal(gain = 0.005))
layer_output = Dense(H => 2)
layer_output.bias .= [log(0.001), StatsFuns.logit(0.1)]  # initial outputs on log and logit scales
model = Chain(layer_1, layer_output)
initial_lr = 0.00005
optim_type = AdamW(initial_lr)
opt = Flux.setup(optim_type, model)

PRS = train_until_convergence(
    summary_stats.coef,
    summary_stats.SE,
    summary_stats.R, 
    XtX,
    Xty,
    raw.G,
    model = model,
    opt = opt,
    σ2 = σ2,
    R2 = R2,
    yty = yty,
    N = fill(N, P),  # Vector of sample sizes
    train_nn = false,
    max_iter = 5
)

Documentation

For more detailed documentation, visit the official documentation site.

Running the Tests

Run unit tests with:

julia --project=. test/runtests.jl

or interactively (with Revise loaded) with:

includet("test/runtests.jl")

Contact

Please address correspondence to:

  • Josh Weinstock <josh.weinstock@emory.edu>
  • April Kim <aprilkim@jhu.edu>
  • Alexis Battle <ajbattle@jhu.edu>

Functions

PRSFNN.main (Function)

PRSFNN

This function defines the command line interface for the PRSFNN package.

Arguments

  • output_prefix: A prefix for the output files
  • annot_data_path: A path to the directory containing the annotations
  • ld_panel_path: A path to the directory containing the LD reference panel
  • gwas_data_path: A path to the GWAS summary statistics file
  • model_file: A path to the file containing the trained model
  • betas_output_file: A path to the file where the PRS betas will be saved
  • interpretation_output_file: A path to the file where the interpretation of the model will be saved

PRSFNN.rss (Function)
rss(β, coef, Σ, SRSinv, to)

Calculate the summary statistic RSS likelihood.

Arguments

  • β::Vector: Vector of effect sizes.
  • coef::Vector: Observed coefficients.
  • Σ::AbstractPDMat: Positive definite covariance matrix.
  • SRSinv::Matrix: Precomputed matrix for efficiency.
  • to: Timer object for benchmarking.

Returns

  • Float64: Log likelihood value.

Example

using PDMats, TimerOutputs

# Σ (positive definite covariance) and SRSinv must already be defined
rss(
    [0.0011, 0.0052, 0.0013],
    [-0.019, 0.013, -0.0199],
    PDMat(Σ),
    SRSinv,
    TimerOutput()
)
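For orientation, the standard RSS likelihood treats the observed marginal coefficients as multivariate normal, coef ~ N(S R S⁻¹ β, S R S) with S = diag(SE). The names Σ and SRSinv suggest that rss receives these two quantities precomputed, but that correspondence is an assumption. The sketch below (hypothetical name rss_loglik_sketch, not part of PRSFNN) evaluates the density directly from SE and R, without caching or timing.

```julia
using LinearAlgebra, Distributions

# Hypothetical sketch, not PRSFNN's rss: RSS log-likelihood
# coef ~ MvNormal(S R S⁻¹ β, S R S), where S = Diagonal(SE).
function rss_loglik_sketch(β, coef, SE, R)
    S = Diagonal(SE)
    Σ = Symmetric(S * R * S)     # covariance of the observed coefficients
    μ = S * R * (S \ β)          # expected marginal coefficients given β
    return logpdf(MvNormal(μ, Σ), coef)
end
```

With R equal to the identity (no LD), this reduces to a sum of independent normal log densities of coef with means β and standard deviations SE.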

PRSFNN.elbo (Function)
elbo(z, q_μ, log_q_var, coef, SE, R, σ2_β, p_causal, σ2, [spike_σ2], to)

Compute the Evidence Lower Bound (ELBO) for variational inference.

Arguments

  • z::Vector: Random vector from standard normal distribution.
  • q_μ::Vector: Mean vector of the variational distribution.
  • log_q_var::Vector: Log variance vector of the variational distribution.
  • coef::Vector: Observed coefficients.
  • SE::Vector or Σ::AbstractPDMat: Standard errors or covariance matrix.
  • R::AbstractArray or SRSinv::Matrix: Correlation matrix or precomputed matrix.
  • σ2_β::Vector: Vector of prior variances for effect sizes.
  • p_causal::Vector: Vector of prior probabilities that each SNP is causal.
  • σ2::Real: Global variance parameter.
  • spike_σ2::Real: Optional variance parameter for the spike component.
  • to: Timer object for benchmarking.

Returns

  • Float64: Computed ELBO value.

Example

using Distributions, TimerOutputs

elbo(
    rand(Normal(0, 1), 3),
    [0.01, -0.003, 0.0018],
    [-9.234, -9.24, -9.24],
    [0.023, -0.0009, -0.0018],
    [0.0094, 0.00988, 0.0102],
    [1.0 0.03 0.017; 0.03 1.0 -0.03; 0.017 -0.03 1.0],
    [0.01, 0.01, 0.01],
    [0.10, 0.10, 0.10],
    0.01,
    TimerOutput()
)
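The z argument is a standard normal draw, which points to a reparameterized Monte Carlo estimate of the ELBO. As a hedged illustration of that idea only (hypothetical name elbo_sketch, assuming a mean-field Gaussian q(β) = N(q_μ, exp(log_q_var)) and a caller-supplied log joint density), one sample gives ELBO ≈ log p(coef, β) − log q(β) with β = q_μ + exp(log_q_var / 2) .* z. PRSFNN's own q is a spike-and-slab mixture, so its ELBO is more involved.

```julia
using Distributions

# Hypothetical sketch, not PRSFNN's elbo: one-sample reparameterized ELBO
# for a mean-field Gaussian variational distribution.
function elbo_sketch(z, q_μ, log_q_var, log_joint)
    q_sd = exp.(log_q_var ./ 2)
    β = q_μ .+ q_sd .* z                          # reparameterization trick
    log_q = sum(logpdf.(Normal.(q_μ, q_sd), β))   # log q(β) at the sample
    return log_joint(β) - log_q
end
```

Averaging several such draws (train_cavi exposes an n_elbo keyword for the number of Monte Carlo samples) reduces the variance of the estimate.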

PRSFNN.joint_log_prob (Function)
joint_log_prob(β, coef, SE, R, σ2_β, p_causal, σ2, [to])

Compute the joint log probability of the model combining likelihood and prior.

Arguments

  • β::Vector: Vector of effect sizes.
  • coef::Vector: Observed coefficients.
  • SE::Vector or Σ::AbstractPDMat: Standard errors or covariance matrix.
  • R::Matrix or SRSinv::Matrix: Correlation matrix or precomputed matrix.
  • σ2_β::Vector: Vector of prior variances for effect sizes.
  • p_causal::Vector: Vector of prior probabilities that each SNP is causal.
  • σ2::Real: Global variance parameter.
  • spike_σ2::Real: Optional variance parameter for the spike component.
  • to: Timer object for benchmarking.

Returns

  • Float64: Joint log probability.

Example

using TimerOutputs

joint_log_prob(
    [0.0011, 0.0052, 0.0013],
    [-0.019, 0.013, -0.0199],
    [0.0098, 0.0098, 0.0102],
    [1.0 0.03 0.017; 0.03 1.0 -0.03; 0.017 -0.03 1.0],
    [0.01, 0.01, 0.01],
    [0.10, 0.10, 0.10],
    0.01,
    TimerOutput()
)

PRSFNN.train_until_convergence (Function)
train_until_convergence(coef, SE, R, D, G; max_iter = 20, threshold = 0.1, N = 10_000)

Arguments

  • coef::Vector: A length P vector of effect sizes
  • SE::Vector: A length P vector of standard errors
  • R::AbstractArray: A P x P correlation matrix
  • D::Vector: A length P vector of the sum of squared genotypes
  • G::AbstractArray: A P x K matrix of annotations

PRSFNN.fit_heritability_nn (Function)

fit_heritability_nn(model, q_μ, q_μ_sq, q_α, G)

Fit the heritability neural network model.

Arguments

  • model::Chain: A neural network model
  • q_μ::Vector: A length P vector of posterior means
  • q_μ_sq::Vector: A length P vector of posterior variances
  • q_α::Vector: A length P vector of posterior probabilities of being causal
  • G::AbstractArray: A P x K matrix of annotations

Example

using PRSFNN, Flux, Distributions

model = Chain(
    Dense(20 => 5, relu; init = Flux.glorot_normal(gain = 0.0005)),
    Dense(5 => 2)
)

G = rand(Normal(0, 1), 100, 20)
q_μ_sq = (G * rand(Normal(0, 0.10), 20)) .^ 2
q_α = 1.0 ./ (1.0 .+ exp.(-1.0 .* (-2.0 .+ q_μ_sq)))
trained_model = PRSFNN.fit_heritability_nn(
    model,
    q_μ_sq,
    q_α,
    G
)
yhat = transpose(trained_model(transpose(G)))
yhat[:, 1] .= exp.(yhat[:, 1])                   # output 1: log scale to variance
yhat[:, 2] .= 1.0 ./ (1.0 .+ exp.(-yhat[:, 2]))  # output 2: logit scale to probability

PRSFNN.log_prior (Function)
log_prior(β, σ2_β, p_causal, σ2, spike_σ2, to)

Calculate the log density of β based on a spike and slab prior.

Arguments

  • β::Vector: Vector of effect sizes.
  • σ2_β::Vector: Vector of prior variances for effect sizes.
  • p_causal::Vector: Vector of prior probabilities that each SNP is causal.
  • σ2::Real: Global variance parameter.
  • spike_σ2::Real: Variance parameter for the spike component.
  • to: Timer object for benchmarking.

Returns

  • Float64: Log prior density.
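A spike-and-slab log prior is a per-SNP two-component mixture: log(p · N(β; 0, slab variance) + (1 − p) · N(β; 0, spike_σ2)), summed over SNPs and evaluated stably with log-sum-exp. The sketch below assumes the slab variance is σ2 · σ2_β; how PRSFNN actually combines σ2 and σ2_β is not stated here, so treat that as an assumption. The names log_prior_sketch and logaddexp_sketch are hypothetical.

```julia
using Distributions

# Stable log(exp(a) + exp(b))
logaddexp_sketch(a, b) = (m = max(a, b); m == -Inf ? -Inf : m + log(exp(a - m) + exp(b - m)))

# Hypothetical sketch, not PRSFNN's log_prior. Assumes slab variance σ2 * σ2_β[j].
function log_prior_sketch(β, σ2_β, p_causal, σ2, spike_σ2)
    lp = 0.0
    for j in eachindex(β)
        slab = log(p_causal[j]) + logpdf(Normal(0, sqrt(σ2 * σ2_β[j])), β[j])
        spike = log1p(-p_causal[j]) + logpdf(Normal(0, sqrt(spike_σ2)), β[j])
        lp += logaddexp_sketch(slab, spike)   # log of the mixture density
    end
    return lp
end
```

When every p_causal is 1, the spike term is log(0) = -Inf and the prior collapses to the slab normal density.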

PRSFNN.estimate_sufficient_statistics (Function)
estimate_sufficient_statistics(X, Y)

Estimate the sufficient statistics for genetic association analysis from genotypes and phenotypes.

Arguments

  • X::AbstractArray: Genotype matrix (N×P) where N is the number of samples and P is the number of SNPs.
  • Y::Vector: Phenotype vector of length N.

Returns

A named tuple containing:

  • coef::Vector: Regression coefficients for each SNP.
  • SE::Vector: Standard errors of the regression coefficients.
  • Z::Vector: Z-scores (coef/SE) for each SNP.
  • R::Matrix: Correlation matrix of the genotypes (LD matrix).
  • D::Vector: Sum of squares for each SNP (diagonal of X'X).

Description

This function computes the basic summary statistics needed for genetic association analysis, including marginal effect sizes, standard errors, test statistics, and the correlation structure between variants.

Example

X = randn(1000, 100)  # 1000 samples, 100 SNPs
Y = randn(1000)       # Phenotypes
stats = estimate_sufficient_statistics(X, Y)
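The quantities listed under Returns follow from one marginal (single-SNP) least-squares regression per column of X. A minimal sketch, assuming X and Y are mean-centered before fitting and residual variance uses N − 2 degrees of freedom; PRSFNN may center or scale differently, and sufficient_stats_sketch is a hypothetical name.

```julia
using Statistics

# Hypothetical sketch, not PRSFNN's estimate_sufficient_statistics:
# marginal OLS per SNP on centered genotypes and phenotype.
function sufficient_stats_sketch(X::AbstractMatrix, Y::AbstractVector)
    N = size(X, 1)
    Xc = X .- mean(X; dims = 1)
    Yc = Y .- mean(Y)
    D = vec(sum(abs2, Xc; dims = 1))     # diagonal of Xc'Xc
    coef = (Xc' * Yc) ./ D               # marginal slopes
    resid_var = [sum(abs2, Yc .- coef[j] .* Xc[:, j]) / (N - 2) for j in eachindex(coef)]
    SE = sqrt.(resid_var ./ D)
    return (coef = coef, SE = SE, Z = coef ./ SE, R = cor(X), D = D)
end
```

Each coef[j] equals the slope from regressing Y on an intercept and X[:, j] alone, which is what GWAS summary statistics report.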

PRSFNN.compute_LD (Function)
compute_LD(LD_reference::String)

Compute linkage disequilibrium (LD) correlation matrix from a reference genotype file.

Arguments

  • LD_reference::String: Path to the BED format genotype file

Returns

  • R::Matrix{Float64}: LD correlation matrix for polymorphic variants
  • sds::Vector{Float64}: Standard deviations of genotypes for each variant
  • allele_freq::Vector{Float64}: Allele frequencies (mean/2) for each variant
  • good_variants::Vector{Int}: Indices of polymorphic variants that passed QC filters

Details

This function reads genotype data from a BED file, converts it to floating point representation, filters out monomorphic variants and those with very low variance, and computes the correlation matrix between the remaining variants. Variants are filtered based on mean frequency and standard deviation thresholds.
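The filtering and correlation steps can be sketched on an in-memory dosage matrix (samples × variants, coded 0/1/2), skipping the BED reading. The thresholds min_freq and min_sd below are illustrative placeholders, not PRSFNN's actual cutoffs, and ld_sketch is a hypothetical name.

```julia
using Statistics

# Hypothetical sketch, not PRSFNN's compute_LD: drop monomorphic or
# near-constant variants, then compute the LD correlation matrix.
function ld_sketch(G::AbstractMatrix; min_freq = 0.01, min_sd = 1e-6)
    allele_freq = vec(mean(G; dims = 1)) ./ 2     # mean dosage / 2
    sds = vec(std(G; dims = 1))
    good = findall((allele_freq .> min_freq) .& (allele_freq .< 1 - min_freq) .& (sds .> min_sd))
    R = cor(G[:, good])
    return (R = R, sds = sds[good], allele_freq = allele_freq[good], good_variants = good)
end
```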


PRSFNN.fit_genome_wide_nn (Function)

fit_genome_wide_nn(betas, annotation_files_dir, model_file; n_epochs, H, n_test, learning_rate_decay, patience)

Train a neural network model to predict variant effects using genome-wide annotation data.

Arguments

  • betas: Path to file containing PRS beta values from previous analysis
  • annotation_files_dir: Directory containing annotation parquet files
  • model_file: Path where the trained model will be saved
  • n_epochs: Number of training epochs (default: 1306)
  • H: Number of hidden units in the neural network (default: 3)
  • n_test: Number of test samples to use for evaluation (default: 30)
  • learning_rate_decay: Learning rate decay factor (default: 0.98)
  • patience: Number of epochs to wait before reducing learning rate (default: 30)

Returns

  • model: The trained neural network model
  • opt: The optimizer state
  • global_σ2: Global residual variance estimate
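The learning_rate_decay and patience arguments describe a "reduce on plateau" schedule. One plausible reading, sketched below under that assumption: when the loss has not improved for patience consecutive epochs, multiply the learning rate by learning_rate_decay. PRSFNN's exact rule may differ, and PlateauSchedule and plateau_step! are hypothetical names.

```julia
# Hypothetical sketch of a reduce-on-plateau learning-rate schedule.
mutable struct PlateauSchedule
    lr::Float64
    decay::Float64
    patience::Int
    best_loss::Float64
    epochs_since_best::Int
end

PlateauSchedule(lr; decay = 0.98, patience = 30) = PlateauSchedule(lr, decay, patience, Inf, 0)

# Record one epoch's loss and return the (possibly reduced) learning rate.
function plateau_step!(s::PlateauSchedule, loss::Real)
    if loss < s.best_loss
        s.best_loss = loss
        s.epochs_since_best = 0
    else
        s.epochs_since_best += 1
        if s.epochs_since_best >= s.patience
            s.lr *= s.decay          # decay after `patience` epochs without improvement
            s.epochs_since_best = 0
        end
    end
    return s.lr
end
```

In a Flux training loop, the returned rate could then be applied to the optimiser state with Flux.adjust!(opt, lr).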

PRSFNN.train_cavi (Function)
train_cavi(p_causal, σ2_β, coef, SE, R, XtX, Xty, to; P = 1_000, n_elbo = 10, max_iter = 5, N = 10_000, yty = 10_000, spike_σ2 = 1e-6, update_σ2 = true, σ2 = 1.0)

Performs Coordinate Ascent Variational Inference (CAVI) for a spike-and-slab model to estimate genetic variant effect sizes and posterior inclusion probabilities.

Arguments

  • p_causal::Vector{Float64}: Prior probabilities for variants to be causal
  • σ2_β::Vector{Float64}: Prior variances for effect sizes (slab component)
  • coef::Vector{Float64}: Effect sizes from GWAS summary statistics
  • SE::Vector{Float64}: Standard errors of the effect sizes
  • R::Matrix{Float64}: LD correlation matrix
  • XtX::Matrix{Float64}: Matrix equal to N times the covariance matrix of genotypes
  • Xty::Vector{Float64}: Vector of inner products between genotypes and phenotype
  • to::TimerOutput: Timer output object for performance profiling

Keyword Arguments

  • P::Int64=1_000: Number of variants
  • n_elbo::Int64=10: Number of Monte Carlo samples for ELBO estimation
  • max_iter::Int64=5: Maximum number of iterations for convergence
  • N::Float64=10_000: Sample size
  • yty::Float64=10_000: Sum of squared phenotypes
  • spike_σ2::Float64=1e-6: Variance for the spike component
  • update_σ2::Bool=true: Whether to update the residual variance
  • σ2::Float64=1.0: Initial residual variance

Returns

A named tuple containing:

  • q_μ::Vector{Float64}: Posterior means for the slab component
  • q_spike_μ::Vector{Float64}: Posterior means for the spike component
  • q_α::Vector{Float64}: Posterior inclusion probabilities
  • q_var::Vector{Float64}: Posterior variances for the slab component
  • q_odds::Vector{Float64}: Posterior odds for variant inclusion
  • loss::Float64: Final ELBO value
  • se_loss::Float64: Standard error of the ELBO estimate
  • σ2::Float64: Final residual variance estimate

Details

This function implements an iterative Coordinate Ascent Variational Inference algorithm for a spike-and-slab model, commonly used in Polygenic Risk Score (PRS) calculation. At each iteration, it updates the variational parameters to maximize the Evidence Lower Bound (ELBO).

The algorithm stops when insufficient improvement in ELBO is detected or when the maximum number of iterations is reached. The best parameters (with highest ELBO) are returned.
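For reference, the classic coordinate-wise CAVI updates for spike-and-slab regression can be written entirely in terms of XtX and Xty. For variant j: remove the other variants' fitted contribution from Xty[j], compute the slab posterior variance and mean, then update the inclusion probability on the log-odds scale. The sketch makes simplifying assumptions that differ from train_cavi: the spike is a point mass at zero (not a Gaussian with spike_σ2), τ is the slab prior variance, σ2 is fixed, and it runs a fixed number of sweeps without ELBO checks. cavi_sketch is a hypothetical name.

```julia
using LinearAlgebra

# Hypothetical sketch, not PRSFNN's train_cavi: CAVI for spike-and-slab
# regression with a point-mass spike, using sufficient statistics.
function cavi_sketch(XtX, Xty, p_causal, τ, σ2; max_iter = 50)
    P = length(Xty)
    α = copy(p_causal)     # posterior inclusion probabilities
    μ = zeros(P)           # posterior slab means
    s2 = zeros(P)          # posterior slab variances
    Eβ = α .* μ            # posterior means E[β]
    for _ in 1:max_iter
        for j in 1:P
            # Xty[j] minus the fitted contribution of all other variants
            r = Xty[j] - dot(view(XtX, :, j), Eβ) + XtX[j, j] * Eβ[j]
            s2[j] = 1 / (XtX[j, j] / σ2 + 1 / τ[j])
            μ[j] = s2[j] * r / σ2
            logit_α = log(p_causal[j] / (1 - p_causal[j])) +
                      0.5 * log(s2[j] / τ[j]) + μ[j]^2 / (2 * s2[j])
            α[j] = 1 / (1 + exp(-logit_α))
            Eβ[j] = α[j] * μ[j]
        end
    end
    return (q_μ = μ, q_var = s2, q_α = α)
end
```

With uncorrelated variants (diagonal XtX) a single sweep converges, and a variant with a large Xty relative to its prior gets an inclusion probability near 1.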


PRSFNN.simulate_raw (Function)
simulate_raw(; N = 10_000, P = 1_000, K = 100, h2 = 0.10)

Simulate genotype and phenotype data with genetic architecture influenced by functional annotations.

Arguments

  • N::Int=10_000: Number of samples/individuals
  • P::Int=1_000: Number of genetic variants/SNPs
  • K::Int=100: Number of functional annotations
  • h2::Float64=0.10: Desired narrow-sense heritability of the trait

Returns

A named tuple containing:

  • X::Matrix{Float64}: Genotype matrix (N × P)
  • β::Vector{Float64}: True causal effect sizes
  • Y::Vector{Float64}: Simulated phenotypes
  • Σ::Matrix{Float64}: Covariance matrix of genotypes (LD structure)
  • s::Vector{Float64}: Per-SNP normalized variance scalars derived from annotations
  • G::Matrix{Float64}: Matrix of functional annotations (P × K)
  • γ::Vector{Int}: Binary indicators of causal status (1=causal, 0=non-causal)
  • function_choices::Vector{Int}: Selected functional forms for each annotation
  • phi::Vector{Function}: Functions mapping each annotation to effect size variance
  • sigma_squared::Vector{Float64}: Per-SNP variance components before normalization

Details

This function simulates genotype and phenotype data with a realistic genetic architecture:

  1. Genotypes (X) are simulated with a realistic LD structure using random eigendecomposition
  2. 10% of variants are designated as causal (γ)
  3. Functional annotations (G) influence effect size variance through different functional forms
  4. Effect sizes follow a spike-and-slab distribution:
    • Causal variants: Normal(0, scaled_variance)
    • Non-causal variants: Normal(0, tiny_variance)
  5. Phenotypes are computed as Y = Xβ + ε, with noise calibrated to achieve target heritability

The simulation includes both continuous and binary annotations, with varying relationships to effect size variance (linear, quadratic, or null effects).
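The heritability calibration in step 5 amounts to setting Var(ε) = Var(Xβ)(1 − h2)/h2, so that Var(Xβ)/Var(Y) ≈ h2. The stripped-down sketch below shows only that calibration. It assumes independent binomial genotypes (no LD) and no annotations, so it is not simulate_raw itself, and simulate_sketch is a hypothetical name.

```julia
using Random, Statistics, Distributions

# Hypothetical sketch: 10% causal variants, noise scaled to hit target h2.
function simulate_sketch(; N = 10_000, P = 1_000, h2 = 0.10, frac_causal = 0.10,
                         rng = Random.default_rng())
    freqs = rand(rng, Uniform(0.05, 0.5), P)          # allele frequencies
    X = Matrix{Float64}(undef, N, P)
    for j in 1:P
        X[:, j] = rand(rng, Binomial(2, freqs[j]), N)  # 0/1/2 dosages, no LD
    end
    γ = rand(rng, P) .< frac_causal                   # causal indicators
    β = zeros(P)
    β[γ] = randn(rng, count(γ))
    g = X * β                                         # genetic component
    σ2_e = var(g) * (1 - h2) / h2                     # noise variance for target h2
    Y = g .+ sqrt(σ2_e) .* randn(rng, N)
    return (X = X, β = β, Y = Y, γ = γ)
end
```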


PRSFNN.infer_σ2 (Function)
infer_σ2(coef, SE, XtX, Xty, N, P)

Arguments

  • coef::Vector: A length P vector of effect sizes
  • SE::Vector: A length P vector of standard errors
  • XtX::AbstractArray: A P x P matrix equal to N times the covariance matrix of the genotypes
  • Xty::Vector: A length P vector of the inner product between genotype and phenotype
  • N: Number of samples
  • P: Number of SNPs

PRSFNN.poet_cov (Function)
poet_cov(X::AbstractArray; K = 100, τ = 0.01, N = 1000)

Apply POET (Principal Orthogonal complEment Thresholding) method to estimate a sparse covariance matrix.

Arguments

  • X::AbstractArray: Input covariance or correlation matrix
  • K::Int=100: Number of principal components to retain
  • τ::Float64=0.01: Thresholding parameter for sparsity
  • N::Int=1000: Sample size used for bias correction

Returns

  • Sigma::Matrix{Float64}: The POET-estimated covariance matrix

Details

Implements the POET method for high-dimensional covariance matrix estimation. The method decomposes the matrix into a low-rank component (representing systematic factors) and a sparse component (representing idiosyncratic noise). The sparse component is obtained by thresholding.

Reference: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5563862/
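The low-rank plus thresholded-residual decomposition can be sketched as follows: keep the top K eigen-components, soft-threshold the off-diagonal entries of the residual at τ, and add the two parts back together. This simplified version uses a single fixed τ with soft thresholding and leaves out the N-based bias correction that poet_cov mentions; poet_sketch is a hypothetical name.

```julia
using LinearAlgebra

# Hypothetical sketch of POET-style estimation, not PRSFNN's poet_cov.
function poet_sketch(S::AbstractMatrix; K = 2, τ = 0.01)
    E = eigen(Symmetric(Matrix(S)))
    idx = sortperm(E.values; rev = true)[1:K]          # top-K eigenvalues
    V = E.vectors[:, idx]
    low_rank = V * Diagonal(E.values[idx]) * V'        # systematic (factor) part
    resid = S - low_rank                               # idiosyncratic part
    thresh = sign.(resid) .* max.(abs.(resid) .- τ, 0.0)   # soft thresholding
    thresh[diagind(thresh)] = diag(resid)              # never threshold the diagonal
    return Symmetric(low_rank + thresh)
end
```

Keeping all components (K equal to the matrix size) returns the input unchanged, while a large τ zeroes every off-diagonal residual and leaves only the factor structure off the diagonal.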
