// luxlabs header browser preview


@class MPSMatrixMultiplication, NSMutableArray;

// Layout of the table returned by +libraryInfo:. Field names (x0, x1, ...) are the
// placeholder names from the original declaration and are kept unchanged.
// Why this is hoisted: the original spelled out `struct MPSDeviceSpecificInfo { ... }`
// inline 42 times inside the method's return type. Struct tags declared there share the
// enclosing scope, so that is a tag-redefinition compile error. The layout below is identical.
struct MPSKernelInfo;

struct MPSDeviceSpecificInfo {
    struct MPSKernelInfo *x0;
    void /* function */ *x1;
    unsigned long long x2;
};

struct MPSLibraryInfo {
    int x0;
    unsigned int x1;
    char *x2;
    void /* function */ *x3;
    // 42 consecutive device-specific entries, x4 through x45 (same order as before).
    struct MPSDeviceSpecificInfo x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14,
                                 x15, x16, x17, x18, x19, x20, x21, x22, x23, x24, x25,
                                 x26, x27, x28, x29, x30, x31, x32, x33, x34, x35, x36,
                                 x37, x38, x39, x40, x41, x42, x43, x44, x45;
};

/// Interface for MPSRNNMatrixTrainingLayer, an MPSKernel subclass.
///
/// NOTE(review): this declaration looks machine-reconstructed (class-dump style): object
/// arguments are typed `id` and instance variables are listed in the @interface. Confirm
/// argument types against Apple's public MetalPerformanceShaders headers before relying
/// on them.
@interface MPSRNNMatrixTrainingLayer : MPSKernel {
    int layerType;
    // Untyped pointer; its pointee type is not visible in this header.
    void *layer;
    // Matrix-multiplication kernels. The variants are distinguished here only by name
    // (accumulate / no-accumulate, non-transposed, "TN"); their configuration is set
    // outside this header.
    MPSMatrixMultiplication *gemmKernel;
    MPSMatrixMultiplication *gemmKernel_noAccumulate;
    MPSMatrixMultiplication *gemmKernelNonTranspose;
    MPSMatrixMultiplication *gemmKernelNonTranspose_noAccumulate;
    MPSMatrixMultiplication *gemmKernelTN;
    MPSMatrixMultiplication *gemmKernelTN_accumulate;
    NSMutableArray *weightDescriptors;
}

// Settable flag.
@property (nonatomic) BOOL propagateFullRecurrentRows;

// Read-only channel counts.
@property (readonly, nonatomic) unsigned long long inputFeatureChannels;
@property (readonly, nonatomic) unsigned long long outputFeatureChannels;

// Settable flags; their effect is implemented outside this header.
@property (nonatomic) BOOL storeAllIntermediateStates;
@property (nonatomic) BOOL recurrentOutputIsTemporary;
@property (nonatomic) BOOL trainingStateIsTemporary;
@property (nonatomic) BOOL accumulateWeightGradients;

/// Returns a pointer to a constant MPSLibraryInfo table (layout declared above).
/// NOTE(review): the meaning of the untyped argument is not visible here.
+ (const struct MPSLibraryInfo *)libraryInfo:(void *)a0;

#pragma mark - Copying / coding

- (id)copyWithZone:(struct _NSZone *)zone device:(id)device;
- (instancetype)initWithCoder:(id)coder device:(id)device;
- (void)dealloc;
- (void)encodeWithCoder:(id)coder;

#pragma mark - Weight matrices

// NOTE(review): dataType is a 32-bit unsigned value, presumably an MPSDataType — confirm.
- (void)createWeightGradientMatrices:(id)matrices dataType:(unsigned int)dataType;
- (void)createWeightMatrices:(id)matrices;

// NOTE(review): matrixOffset holds three unsigned 64-bit components, presumably an
// MTLOrigin (x, y, z) — TODO confirm. The anonymous struct type is kept as declared.
- (void)encodeCopyWeightsToCommandBuffer:(id)commandBuffer
                                 weights:(id)weights
                                matrixId:(unsigned long long)matrixId
                                  matrix:(id)matrix
                 copyFromWeightsToMatrix:(BOOL)copyFromWeightsToMatrix
                            matrixOffset:(struct { unsigned long long x0; unsigned long long x1; unsigned long long x2; })matrixOffset;

#pragma mark - Encoding
// Each pass has two overloads; the shorter one omits the per-matrix offset arrays and the
// recurrent input/output state arguments.

- (void)encodeForwardSequenceToCommandBuffer:(id)commandBuffer
                              sourceMatrices:(id)sourceMatrices
                         destinationMatrices:(id)destinationMatrices
                              trainingStates:(id)trainingStates
                                     weights:(id)weights;
- (void)encodeForwardSequenceToCommandBuffer:(id)commandBuffer
                              sourceMatrices:(id)sourceMatrices
                               sourceOffsets:(unsigned long long *)sourceOffsets
                         destinationMatrices:(id)destinationMatrices
                          destinationOffsets:(unsigned long long *)destinationOffsets
                              trainingStates:(id)trainingStates
                         recurrentInputState:(id)recurrentInputState
                       recurrentOutputStates:(id)recurrentOutputStates
                                     weights:(id)weights;
- (void)encodeGradientSequenceToCommandBuffer:(id)commandBuffer
                               forwardSources:(id)forwardSources
                         forwardSourceOffsets:(unsigned long long *)forwardSourceOffsets
                              sourceGradients:(id)sourceGradients
                        sourceGradientOffsets:(unsigned long long *)sourceGradientOffsets
                         destinationGradients:(id)destinationGradients
                           destinationOffsets:(unsigned long long *)destinationOffsets
                              weightGradients:(id)weightGradients
                               trainingStates:(id)trainingStates
                          recurrentInputState:(id)recurrentInputState
                        recurrentOutputStates:(id)recurrentOutputStates
                                      weights:(id)weights;
- (void)encodeGradientSequenceToCommandBuffer:(id)commandBuffer
                               forwardSources:(id)forwardSources
                              sourceGradients:(id)sourceGradients
                         destinationGradients:(id)destinationGradients
                              weightGradients:(id)weightGradients
                               trainingStates:(id)trainingStates
                                      weights:(id)weights;

#pragma mark - Initialization

- (instancetype)initWithDevice:(id)device
                 rnnDescriptor:(id)rnnDescriptor
              trainableWeights:(id)trainableWeights;

#pragma mark - Recurrent state and temporary gradients

- (id)temporaryRecurrentStateForCommandBuffer:(id)commandBuffer
                                    batchSize:(unsigned long long)batchSize
                              forGradientPass:(BOOL)forGradientPass;
- (void)createTemporaryWeightGradientMatrices:(id)matrices
                                     dataType:(unsigned int)dataType
                                commandBuffer:(id)commandBuffer;
- (id)recurrentStateForBatchSize:(unsigned long long)batchSize
                 forGradientPass:(BOOL)forGradientPass;

@end