--- /dev/null
+import torch
+import torch.nn as nn
+
class LinearRegressionModel(nn.Module):
    """A plain linear regression layer: y = x @ W^T + b.

    Args:
        n_features: Number of input features per sample.
        bias: If True, add a learnable scalar bias term; otherwise the
            model computes a pure linear map with no offset.
    """

    def __init__(self, n_features: int, bias: bool = True) -> None:
        super().__init__()
        self.n_features = n_features
        # nn.Parameter already sets requires_grad=True by default.
        self.weights = nn.Parameter(torch.empty((1, n_features)))
        if bias:
            # Note: torch.empty((1)) == torch.empty(1); use the scalar form.
            self.bias = nn.Parameter(torch.empty(1))
        else:
            # Register as None so self.bias always exists and the module's
            # parameter bookkeeping (state_dict, named_parameters) stays consistent.
            self.register_parameter("bias", None)
        self.initialize_parameters()

    def initialize_parameters(self) -> None:
        """(Re)initialize weights (Kaiming-uniform) and bias (uniform in [0, 1))."""
        nn.init.kaiming_uniform_(self.weights)
        if self.bias is not None:
            nn.init.uniform_(self.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute the linear map y = x @ W^T, plus b when bias is enabled.

        Args:
            x: Input tensor of shape (batch, n_features).

        Returns:
            Tensor of shape (batch, 1).
        """
        out = x @ self.weights.T
        # Bug fix: the original added self.bias unconditionally, which raised
        # a TypeError whenever the model was built with bias=False
        # (self.bias is None in that case).
        if self.bias is not None:
            out = out + self.bias
        return out
+
+
def main() -> None:
    """Smoke-test: push one random batch through the regression model."""
    feature_count = 3
    sample_count = 5
    inputs = torch.randn(sample_count, feature_count)
    regressor = LinearRegressionModel(feature_count)
    regressor(inputs)
+
+
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()