Skip to content
Snippets Groups Projects
Commit 5512fa99 authored by Bastien Batardière's avatar Bastien Batardière
Browse files

remove useless tests.

parent 6eaabd2a
No related branches found
No related tags found
1 merge request!15merged dmatrix branch that uses dmatrix to initialize each model.
import os
from pyPLNmodels.models import PLN, PLNPCA, _PLNPCA
from pyPLNmodels import get_simulated_count_data, get_real_count_data
import pytest
from pytest_lazyfixture import lazy_fixture as lf
import pandas as pd
import numpy as np
(
counts_sim,
covariates_sim,
offsets_sim,
) = get_simulated_count_data(nb_cov=2)
couts_real = get_real_count_data(n_samples=298, dim=101)
RANKS = [2, 8]
@pytest.fixture
def instance_plnpca():
plnpca = PLNPCA(ranks=RANKS)
return plnpca
@pytest.fixture
def instance__plnpca():
model = _PLNPCA(rank=RANKS[0])
return model
@pytest.fixture
def instance_pln_full():
return PLN()
all_instances = [lf("instance_plnpca"), lf("instance__plnpca"), lf("instance_pln_full")]
@pytest.mark.parametrize("instance", all_instances)
def test_pandas_init(instance):
instance.fit(
pd.DataFrame(counts_sim.numpy()),
pd.DataFrame(covariates_sim.numpy()),
pd.DataFrame(offsets_sim.numpy()),
)
@pytest.mark.parametrize("instance", all_instances)
def test_numpy_init(instance):
instance.fit(counts_sim.numpy(), covariates_sim.numpy(), offsets_sim.numpy())
@pytest.mark.parametrize("sim_pln", simulated_any_pln)
def test_only_counts(sim_pln):
sim_pln.fit()
@pytest.mark.parametrize("sim_pln", simulated_any_pln)
def test_only_counts_and_offsets(sim_pln):
sim_pln.fit(counts=counts_sim, offsets=offsets_sim)
@pytest.mark.parametrize("sim_pln", simulated_any_pln)
def test_only_Y_and_cov(sim_pln):
sim_pln.fit(counts=counts_sim, covariates=covariates_sim)
......@@ -119,6 +119,7 @@ def test_fail_count_setter(pln):
def test_setter_with_numpy(pln):
np_counts = pln.counts.numpy()
pln.counts = np_counts
pln.fit()
@pytest.mark.parametrize("pln", dict_fixtures["all_pln"])
......@@ -126,6 +127,7 @@ def test_setter_with_numpy(pln):
def test_setter_with_pandas(pln):
pd_counts = pd.DataFrame(pln.counts.numpy())
pln.counts = pd_counts
pln.fit()
@pytest.mark.parametrize("instance", dict_fixtures["instances"])
......
......@@ -2,8 +2,3 @@ import torch
from import_fixtures_and_data import get_dict_fixtures
from pyPLNmodels import PLN
df = get_dict_fixtures(PLN)
for key, fixture in df.items():
print(len(fixture))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment