scripts/last-train
author Martin C. Frith
Fri Jun 02 18:40:29 2017 +0900 (2017-06-02)
changeset 863 6a4915d5b5cb
parent 848 7c09e0d848f2
child 867 f9ec9c71d72e
permissions -rwxr-xr-x
last-dotplot: get bp-per-pixel faster
     1 #! /usr/bin/env python
     2 # Copyright 2015 Martin C. Frith
     3 
     4 import math, optparse, os, random, signal, subprocess, sys, tempfile
     5 
     6 def randomSample(things, sampleSize):
     7     """Randomly get sampleSize things (or all if fewer)."""
     8     reservoir = []  # "reservoir sampling" algorithm
     9     for i, x in enumerate(things):
    10         if i < sampleSize:
    11             reservoir.append(x)
    12         else:
    13             r = random.randrange(i + 1)
    14             if r < sampleSize:
    15                 reservoir[r] = x
    16     return reservoir
    17 
    18 def writeWords(outFile, words):
    19     outFile.write(" ".join(words) + "\n")
    20 
    21 def seqInput(fileNames):
    22     for name in fileNames:
    23         with open(name) as file:
    24             seqType = 0
    25             for line in file:
    26                 if seqType == 0:
    27                     if line[0] == ">":
    28                         seqType = 1
    29                         seq = []
    30                     elif line[0] == "@":
    31                         seqType = 2
    32                         lineType = 1
    33                 elif seqType == 1:  # fasta
    34                     if line[0] == ">":
    35                         yield "".join(seq), ""
    36                         seq = []
    37                     else:
    38                         seq.append(line.rstrip())
    39                 elif seqType == 2:  # fastq
    40                     if lineType == 1:
    41                         seq = line.rstrip()
    42                     elif lineType == 3:
    43                         yield seq, line.rstrip()
    44                     lineType = (lineType + 1) % 4
    45             if seqType == 1: yield "".join(seq), ""
    46 
    47 def isGoodChunk(chunk):
    48     for i in chunk:
    49         for j in i[3]:
    50             if j not in "Nn":
    51                 return True
    52     return False
    53 
    54 def chunkInput(opts, sequences):
    55     chunkCount = 0
    56     chunk = []
    57     wantedLength = opts.sample_length
    58     for i, x in enumerate(sequences):
    59         seq, qual = x
    60         if all(i in "Nn" for i in seq): continue
    61         seqLength = len(seq)
    62         beg = 0
    63         while beg < seqLength:
    64             length = min(wantedLength, seqLength - beg)
    65             end = beg + length
    66             segment = i, beg, end, seq[beg:end], qual[beg:end]
    67             chunk.append(segment)
    68             wantedLength -= length
    69             if not wantedLength:
    70                 if isGoodChunk(chunk):
    71                     yield chunk
    72                     chunkCount += 1
    73                 chunk = []
    74                 wantedLength = opts.sample_length
    75             beg = end
    76     if chunk and chunkCount < opts.sample_number:
    77         yield chunk
    78 
    79 def writeSegment(outfile, segment):
    80     if not segment: return
    81     i, beg, end, seq, qual = segment
    82     name = str(i) + ":" + str(beg)
    83     if qual:
    84         outfile.write("@" + name + "\n")
    85         outfile.write(seq)
    86         outfile.write("\n+\n")
    87         outfile.write(qual)
    88     else:
    89         outfile.write(">" + name + "\n")
    90         outfile.write(seq)
    91     outfile.write("\n")
    92 
    93 def getSeqSample(opts, queryFiles, outfile):
    94     sequences = seqInput(queryFiles)
    95     chunks = chunkInput(opts, sequences)
    96     sample = randomSample(chunks, opts.sample_number)
    97     sample.sort()
    98     x = None
    99     for chunk in sample:
   100         for y in chunk:
   101             if x and y[0] == x[0] and y[1] == x[2]:
   102                 x = x[0], x[1], y[2], x[3] + y[3], x[4] + y[4]
   103             else:
   104                 writeSegment(outfile, x)
   105                 x = y
   106     writeSegment(outfile, x)
   107 
   108 def scaleFromHeader(lines):
   109     for line in lines:
   110         for i in line.split():
   111             if i.startswith("t="):
   112                 return float(i[2:])
   113     raise Exception("couldn't read the scale")
   114 
   115 def scoreMatrixFromHeader(lines):
   116     matrix = []
   117     for line in lines:
   118         w = line.split()
   119         if len(w) > 2 and len(w[1]) == 1:
   120             matrix.append(w[1:])
   121         elif matrix:
   122             break
   123     return matrix
   124 
   125 def scaledMatrix(matrix, scaleIncrease):
   126     return matrix[0:1] + [i[0:1] + [int(j) * scaleIncrease for j in i[1:]]
   127                           for i in matrix[1:]]
   128 
   129 def countsFromLastOutput(lines, opts):
   130     matrix = []
   131     # use +1 pseudocounts as a kludge to mitigate numerical problems:
   132     matches = 1.0
   133     deletes = 2.0  # 1 open + 1 extension
   134     inserts = 2.0  # 1 open + 1 extension
   135     delOpens = 1.0
   136     insOpens = 1.0
   137     alignments = 0  # no pseudocount here
   138     for line in lines:
   139         if line[0] == "s":
   140             strand = line.split()[4]  # slow?
   141         if line[0] == "c":
   142             c = map(float, line.split()[1:])
   143             if not matrix:
   144                 matrixSize = int(math.sqrt(len(c) - 10))
   145                 matrix = [[1.0] * matrixSize for i in range(matrixSize)]
   146             identities = sum(c[i * matrixSize + i] for i in range(matrixSize))
   147             alignmentLength = c[-10] + c[-9] + c[-8]
   148             if 100 * identities > opts.pid * alignmentLength: continue
   149             for i in range(matrixSize):
   150                 for j in range(matrixSize):
   151                     if strand == "+" or opts.S == "0":
   152                         matrix[i][j]       += c[i * matrixSize + j]
   153                     else:
   154                         matrix[-1-i][-1-j] += c[i * matrixSize + j]
   155             matches += c[-10]
   156             deletes += c[-9]
   157             inserts += c[-8]
   158             delOpens += c[-7]
   159             insOpens += c[-6]
   160             alignments += 1
   161     gapCounts = matches, deletes, inserts, delOpens, insOpens, alignments
   162     return matrix, gapCounts
   163 
   164 def scoreFromProb(scale, prob):
   165     if prob > 0: logProb = math.log(prob)
   166     else:        logProb = -800  # exp(-800) is exactly zero, on my computer
   167     return int(round(scale * logProb))
   168 
   169 def costFromProb(scale, prob):
   170     return -scoreFromProb(scale, prob)
   171 
   172 def guessAlphabet(matrixSize):
   173     if matrixSize ==  4: return "ACGT"
   174     if matrixSize == 20: return "ACDEFGHIKLMNPQRSTVWY"
   175     raise Exception("can't handle unusual alphabets")
   176 
   177 def matrixWithLetters(matrix):
   178     alphabet = guessAlphabet(len(matrix))
   179     return [alphabet] + [[a] + i for a, i in zip(alphabet, matrix)]
   180 
   181 def writeMatrixHead(outFile, prefix, alphabet, formatString):
   182     writeWords(outFile, [prefix + " "] + [formatString % k for k in alphabet])
   183 
   184 def writeMatrixBody(outFile, prefix, alphabet, matrix, formatString):
   185     for i, j in zip(alphabet, matrix):
   186         writeWords(outFile, [prefix + i] + [formatString % k for k in j])
   187 
   188 def writeCountMatrix(outFile, matrix, prefix):
   189     alphabet = guessAlphabet(len(matrix))
   190     writeMatrixHead(outFile, prefix, alphabet, "%-14s")
   191     writeMatrixBody(outFile, prefix, alphabet, matrix, "%-14s")
   192 
   193 def writeProbMatrix(outFile, matrix, prefix):
   194     alphabet = guessAlphabet(len(matrix))
   195     writeMatrixHead(outFile, prefix, alphabet, "%-14s")
   196     writeMatrixBody(outFile, prefix, alphabet, matrix, "%-14g")
   197 
   198 def writeScoreMatrix(outFile, matrix, prefix):
   199     alphabet = guessAlphabet(len(matrix))
   200     writeMatrixHead(outFile, prefix, alphabet, "%6s")
   201     writeMatrixBody(outFile, prefix, alphabet, matrix, "%6s")
   202 
   203 def writeMatrixWithLetters(outFile, matrix, prefix):
   204     head = matrix[0]
   205     tail = matrix[1:]
   206     left = [i[0] for i in tail]
   207     body = [i[1:] for i in tail]
   208     writeMatrixHead(outFile, prefix, head, "%6s")
   209     writeMatrixBody(outFile, prefix, left, body, "%6s")
   210 
   211 def matProbsFromCounts(counts, opts):
   212     r = range(len(counts))
   213     if opts.revsym:  # add complement (reverse strand) substitutions
   214         counts = [[counts[i][j] + counts[-1-i][-1-j] for j in r] for i in r]
   215     if opts.matsym:  # symmetrize the substitution matrix
   216         counts = [[counts[i][j] + counts[j][i] for j in r] for i in r]
   217     identities = sum(counts[i][i] for i in r)
   218     total = sum(map(sum, counts))
   219     probs = [[j / total for j in i] for i in counts]
   220 
   221     print "# substitution percent identity: %g" % (100 * identities / total)
   222     print
   223     print "# count matrix (query letters = columns, reference letters = rows):"
   224     writeCountMatrix(sys.stdout, counts, "# ")
   225     print
   226     print "# probability matrix (query letters = columns, reference letters = rows):"
   227     writeProbMatrix(sys.stdout, probs, "# ")
   228     print
   229 
   230     return probs
   231 
   232 def gapProbsFromCounts(counts, opts):
   233     matches, deletes, inserts, delOpens, insOpens, alignments = counts
   234     if not alignments: raise Exception("no alignments")
   235     gaps = deletes + inserts
   236     gapOpens = delOpens + insOpens
   237     denominator = matches + gapOpens + (alignments + 1)  # +1 pseudocount
   238     if opts.gapsym:
   239         delOpenProb = gapOpens / denominator / 2
   240         insOpenProb = gapOpens / denominator / 2
   241         delExtendProb = (gaps - gapOpens) / gaps
   242         insExtendProb = (gaps - gapOpens) / gaps
   243     else:
   244         delOpenProb = delOpens / denominator
   245         insOpenProb = insOpens / denominator
   246         delExtendProb = (deletes - delOpens) / deletes
   247         insExtendProb = (inserts - insOpens) / inserts
   248 
   249     print "# aligned letter pairs:", matches
   250     print "# deletes:", deletes
   251     print "# inserts:", inserts
   252     print "# delOpens:", delOpens
   253     print "# insOpens:", insOpens
   254     print "# alignments:", alignments
   255     print "# mean delete size: %g" % (deletes / delOpens)
   256     print "# mean insert size: %g" % (inserts / insOpens)
   257     print "# delOpenProb: %g" % delOpenProb
   258     print "# insOpenProb: %g" % insOpenProb
   259     print "# delExtendProb: %g" % delExtendProb
   260     print "# insExtendProb: %g" % insExtendProb
   261     print
   262 
   263     delCloseProb = 1 - delExtendProb
   264     insCloseProb = 1 - insExtendProb
   265     firstDelProb = delOpenProb * delCloseProb
   266     firstInsProb = insOpenProb * insCloseProb
   267 
   268     # If we define "an alignment" to mean "a set of indistinguishable
   269     # paths", then:
   270     #delExtendProb += firstDelProb
   271     #insExtendProb += firstInsProb
   272     # Else, this ensures gap existence cost >= 0:
   273     delExtendProb = max(delExtendProb, firstDelProb)
   274     insExtendProb = max(insExtendProb, firstInsProb)
   275 
   276     delExistProb = firstDelProb / delExtendProb
   277     insExistProb = firstInsProb / insExtendProb
   278 
   279     return delExistProb, insExistProb, delExtendProb, insExtendProb
   280 
   281 def scoreFromLetterProbs(scale, pairProb, prob1, prob2):
   282     probRatio = pairProb / (prob1 * prob2)
   283     return scoreFromProb(scale, probRatio)
   284 
   285 def matScoresFromProbs(scale, probs):
   286     rowProbs = map(sum, probs)
   287     colProbs = map(sum, zip(*probs))
   288     return [[scoreFromLetterProbs(scale, j, x, y) for j, y in zip(i, colProbs)]
   289             for i, x in zip(probs, rowProbs)]
   290 
   291 def gapCostsFromProbs(scale, probs):
   292     delExistProb, insExistProb, delExtendProb, insExtendProb = probs
   293     delExistCost = costFromProb(scale, delExistProb)
   294     insExistCost = costFromProb(scale, insExistProb)
   295     delExtendCost = costFromProb(scale, delExtendProb)
   296     insExtendCost = costFromProb(scale, insExtendProb)
   297     if delExtendCost == 0: delExtendCost = 1
   298     if insExtendCost == 0: insExtendCost = 1
   299     return delExistCost, insExistCost, delExtendCost, insExtendCost
   300 
   301 def writeLine(out, *things):
   302     out.write(" ".join(map(str, things)) + "\n")
   303 
   304 def writeGapCosts(gapCosts, out):
   305     delExistCost, insExistCost, delExtendCost, insExtendCost = gapCosts
   306     writeLine(out, "#last -a", delExistCost)
   307     writeLine(out, "#last -A", insExistCost)
   308     writeLine(out, "#last -b", delExtendCost)
   309     writeLine(out, "#last -B", insExtendCost)
   310 
   311 def printGapCosts(gapCosts):
   312     delExistCost, insExistCost, delExtendCost, insExtendCost = gapCosts
   313     print "# delExistCost:", delExistCost
   314     print "# insExistCost:", insExistCost
   315     print "# delExtendCost:", delExtendCost
   316     print "# insExtendCost:", insExtendCost
   317     print
   318 
   319 def tryToMakeChildProgramsFindable():
   320     myDir = os.path.dirname(__file__)
   321     srcDir = os.path.join(myDir, os.pardir, "src")
   322     # put srcDir first, to avoid getting older versions of LAST:
   323     os.environ["PATH"] = srcDir + os.pathsep + os.environ["PATH"]
   324 
   325 def fixedLastalArgs(opts):
   326     x = ["lastal", "-j7"]
   327     if opts.D: x.append("-D" + opts.D)
   328     if opts.E: x.append("-E" + opts.E)
   329     if opts.s: x.append("-s" + opts.s)
   330     if opts.S: x.append("-S" + opts.S)
   331     if opts.C: x.append("-C" + opts.C)
   332     if opts.T: x.append("-T" + opts.T)
   333     if opts.m: x.append("-m" + opts.m)
   334     if opts.P: x.append("-P" + opts.P)
   335     if opts.Q: x.append("-Q" + opts.Q)
   336     return x
   337 
   338 def doTraining(opts, args):
   339     tryToMakeChildProgramsFindable()
   340     scaleIncrease = 20  # while training, up-scale the scores by this amount
   341     x = fixedLastalArgs(opts)
   342     if opts.r: x.append("-r" + opts.r)
   343     if opts.q: x.append("-q" + opts.q)
   344     if opts.p: x.append("-p" + opts.p)
   345     if opts.a: x.append("-a" + opts.a)
   346     if opts.b: x.append("-b" + opts.b)
   347     if opts.A: x.append("-A" + opts.A)
   348     if opts.B: x.append("-B" + opts.B)
   349     x += args
   350     y = ["last-split", "-n"]
   351     p = subprocess.Popen(x, stdout=subprocess.PIPE)
   352     q = subprocess.Popen(y, stdin=p.stdout, stdout=subprocess.PIPE)
   353     externalScale = scaleFromHeader(q.stdout)
   354     internalScale = externalScale * scaleIncrease
   355     if opts.Q:
   356         externalMatrix = scoreMatrixFromHeader(q.stdout)
   357         internalMatrix = scaledMatrix(externalMatrix, scaleIncrease)
   358     oldParameters = []
   359 
   360     print "# maximum percent identity:", opts.pid
   361     print "# scale of score parameters:", externalScale
   362     print "# scale used while training:", internalScale
   363     print
   364 
   365     while True:
   366         print "#", " ".join(x)
   367         print
   368         sys.stdout.flush()
   369         matCounts, gapCounts = countsFromLastOutput(q.stdout, opts)
   370         gapProbs = gapProbsFromCounts(gapCounts, opts)
   371         gapCosts = gapCostsFromProbs(internalScale, gapProbs)
   372         printGapCosts(gapCosts)
   373         if opts.Q:
   374             if gapCosts in oldParameters: break
   375             oldParameters.append(gapCosts)
   376         else:
   377             matProbs = matProbsFromCounts(matCounts, opts)
   378             matScores = matScoresFromProbs(internalScale, matProbs)
   379             print "# score matrix (query letters = columns, reference letters = rows):"
   380             writeScoreMatrix(sys.stdout, matScores, "# ")
   381             print
   382             parameters = gapCosts, matScores
   383             if parameters in oldParameters: break
   384             oldParameters.append(parameters)
   385             internalMatrix = matrixWithLetters(matScores)
   386         x = fixedLastalArgs(opts)
   387         x.append("-p-")
   388         x += args
   389         p = subprocess.Popen(x, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
   390         writeGapCosts(gapCosts, p.stdin)
   391         writeMatrixWithLetters(p.stdin, internalMatrix, "")
   392         p.stdin.close()
   393         # in python2.6, the next line must come after p.stdin.close()
   394         q = subprocess.Popen(y, stdin=p.stdout, stdout=subprocess.PIPE)
   395 
   396     gapCosts = gapCostsFromProbs(externalScale, gapProbs)
   397     writeGapCosts(gapCosts, sys.stdout)
   398     if opts.s: writeLine(sys.stdout, "#last -s", opts.s)
   399     if opts.S: writeLine(sys.stdout, "#last -S", opts.S)
   400     if not opts.Q:
   401         matScores = matScoresFromProbs(externalScale, matProbs)
   402         externalMatrix = matrixWithLetters(matScores)
   403     print "# score matrix (query letters = columns, reference letters = rows):"
   404     writeMatrixWithLetters(sys.stdout, externalMatrix, "")
   405 
   406 def lastTrain(opts, args):
   407     if opts.sample_number:
   408         random.seed(math.pi)
   409         refName = args[0]
   410         queryFiles = args[1:]
   411         try:
   412             with tempfile.NamedTemporaryFile(delete=False) as f:
   413                 getSeqSample(opts, queryFiles, f)
   414             doTraining(opts, [refName, f.name])
   415         finally:
   416             os.remove(f.name)
   417     else:
   418         doTraining(opts, args)
   419 
   420 if __name__ == "__main__":
   421     signal.signal(signal.SIGPIPE, signal.SIG_DFL)  # avoid silly error message
   422     usage = "%prog [options] lastdb-name sequence-file(s)"
   423     description = "Try to find suitable score parameters for aligning the given sequences."
   424     op = optparse.OptionParser(usage=usage, description=description)
   425     og = optparse.OptionGroup(op, "Training options")
   426     og.add_option("--revsym", action="store_true",
   427                   help="force reverse-complement symmetry")
   428     og.add_option("--matsym", action="store_true",
   429                   help="force symmetric substitution matrix")
   430     og.add_option("--gapsym", action="store_true",
   431                   help="force insertion/deletion symmetry")
   432     og.add_option("--pid", type="float", default=100, help=
   433                   "skip alignments with > PID% identity (default: %default)")
   434     og.add_option("--sample-number", type="int", default=500, metavar="N",
   435                   help="number of random sequence samples (default: %default)")
   436     og.add_option("--sample-length", type="int", default=2000, metavar="L",
   437                   help="length of each sample (default: %default)")
   438     op.add_option_group(og)
   439     og = optparse.OptionGroup(op, "Initial parameter options")
   440     og.add_option("-r", metavar="SCORE",
   441                   help="match score (default: 6 if Q>0, else 5)")
   442     og.add_option("-q", metavar="COST",
   443                   help="mismatch cost (default: 18 if Q>0, else 5)")
   444     og.add_option("-p", metavar="NAME", help="match/mismatch score matrix")
   445     og.add_option("-a", metavar="COST",
   446                   help="gap existence cost (default: 21 if Q>0, else 15)")
   447     og.add_option("-b", metavar="COST",
   448                   help="gap extension cost (default: 9 if Q>0, else 3)")
   449     og.add_option("-A", metavar="COST", help="insertion existence cost")
   450     og.add_option("-B", metavar="COST", help="insertion extension cost")
   451     op.add_option_group(og)
   452     og = optparse.OptionGroup(op, "Alignment options")
   453     og.add_option("-D", metavar="LENGTH",
   454                   help="query letters per random alignment (default: 1e6)")
   455     og.add_option("-E", metavar="EG2",
   456                   help="maximum expected alignments per square giga")
   457     og.add_option("-s", metavar="STRAND", help=
   458                   "0=reverse, 1=forward, 2=both (default: 2 if DNA, else 1)")
   459     og.add_option("-S", metavar="NUMBER", default="1", help=
   460                   "score matrix applies to forward strand of: " +
   461                   "0=reference, 1=query (default: %default)")
   462     og.add_option("-C", metavar="COUNT", help=
   463                   "omit gapless alignments in COUNT others with > score-per-length")
   464     og.add_option("-T", metavar="NUMBER",
   465                   help="type of alignment: 0=local, 1=overlap (default: 0)")
   466     og.add_option("-m", metavar="COUNT", help=
   467                   "maximum initial matches per query position (default: 10)")
   468     og.add_option("-P", metavar="THREADS",
   469                   help="number of parallel threads")
   470     og.add_option("-Q", metavar="NUMBER",
   471                   help="input format: 0=fasta, 1=fastq-sanger")
   472     op.add_option_group(og)
   473     (opts, args) = op.parse_args()
   474     if len(args) < 2: op.error("I need a lastdb index and query sequences")
   475     if not opts.p and (not opts.Q or opts.Q == "0"):
   476         if not opts.r: opts.r = "5"
   477         if not opts.q: opts.q = "5"
   478         if not opts.a: opts.a = "15"
   479         if not opts.b: opts.b = "3"
   480 
   481     try: lastTrain(opts, args)
   482     except KeyboardInterrupt: pass  # avoid silly error message
   483     except Exception, e:
   484         prog = os.path.basename(sys.argv[0])
   485         sys.exit(prog + ": error: " + str(e))